Skip to content

Commit 2efed33

Browse files
committed
add compute_output_shape
1 parent 4202554 commit 2efed33

File tree

4 files changed

+26
-30
lines changed

4 files changed

+26
-30
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,11 @@ def build(self, xz_shape, conditions_shape=None):
179179
out_shape = self.subnet.compute_output_shape(input_shape)
180180
else:
181181
# Multiple separate inputs
182-
main_input_shape = xz_shape
183182
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
184-
185-
# Build subnet with multiple input shapes
186-
input_shape = [main_input_shape, time_shape]
187-
if conditions_shape is not None:
188-
input_shape.append(conditions_shape)
189-
190-
self.subnet.build(input_shape) # Pass list of shapes
191-
out_shape = self.subnet.compute_output_shape(input_shape)
192-
183+
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
184+
out_shape = self.subnet.compute_output_shape(
185+
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
186+
)
193187
self.output_projector.build(out_shape)
194188

195189
# Choose coefficient according to [2] Section 3.3

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,11 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
146146
out_shape = self.subnet.compute_output_shape(input_shape)
147147
else:
148148
# Multiple separate inputs
149-
main_input_shape = xz_shape
150149
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
151-
152-
# Build subnet with multiple input shapes
153-
input_shape = [main_input_shape, time_shape]
154-
if conditions_shape is not None:
155-
input_shape.append(conditions_shape)
156-
157-
self.subnet.build(input_shape) # Pass list of shapes
158-
out_shape = self.subnet.compute_output_shape(input_shape)
150+
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
151+
out_shape = self.subnet.compute_output_shape(
152+
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
153+
)
159154

160155
self.output_projector.build(out_shape)
161156

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,11 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
140140
out_shape = self.subnet.compute_output_shape(input_shape)
141141
else:
142142
# Multiple separate inputs
143-
main_input_shape = xz_shape
144143
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
145-
146-
# Build subnet with multiple input shapes
147-
input_shape = [main_input_shape, time_shape]
148-
if conditions_shape is not None:
149-
input_shape.append(conditions_shape)
150-
151-
self.subnet.build(input_shape) # Pass list of shapes
152-
out_shape = self.subnet.compute_output_shape(input_shape)
144+
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
145+
out_shape = self.subnet.compute_output_shape(
146+
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
147+
)
153148

154149
self.output_projector.build(out_shape)
155150

tests/test_networks/conftest.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,25 @@ def call(self, x, t, conditions=None, training=False):
2828
con = concatenate_valid([x, t, conditions], axis=-1)
2929
return self.mlp(con)
3030

31-
def compute_output_shape(self, input_shape):
31+
def compute_output_shape(self, input_shape_x, input_shape_t, input_shape_conditions=None):
3232
concatenate_input_shapes = tuple(
33-
(input_shape[0][0], sum([shape[-1] for shape in input_shape if shape is not None]))
33+
(
34+
input_shape_x[0],
35+
input_shape_x[-1]
36+
+ input_shape_t[-1]
37+
+ (input_shape_conditions[-1] if input_shape_conditions is not None else 0),
38+
)
3439
)
3540
out = self.mlp.compute_output_shape(concatenate_input_shapes)
3641
return out
3742

43+
def build(self, input_shape_x, input_shape_t, input_shape_conditions=None):
44+
if self.built:
45+
return
46+
47+
input_shape = self.compute_output_shape(input_shape_x, input_shape_t, input_shape_conditions)
48+
self.mlp.build(input_shape)
49+
3850

3951
@pytest.fixture()
4052
def diffusion_model_edm_F_subnet_concatenate():

0 commit comments

Comments
 (0)