Skip to content

Commit 9fe12a8

Browse files
committed
fix input shape
1 parent ef1f47c commit 9fe12a8

File tree

4 files changed: 59 additions and 14 deletions.

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -169,18 +169,28 @@ def build(self, xz_shape, conditions_shape=None):
169169
input_shape = list(xz_shape)
170170

171171
if self._concatenate_subnet_input:
172-
# time vector
172+
# construct time vector
173173
input_shape[-1] += 1
174-
175174
if conditions_shape is not None:
176175
input_shape[-1] += conditions_shape[-1]
176+
input_shape = tuple(input_shape)
177+
178+
self.subnet.build(input_shape)
179+
out_shape = self.subnet.compute_output_shape(input_shape)
180+
else:
181+
# Multiple separate inputs
182+
main_input_shape = xz_shape
183+
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
177184

178-
input_shape = tuple(input_shape)
185+
# Build subnet with multiple input shapes
186+
input_shape = [main_input_shape, time_shape]
187+
if conditions_shape is not None:
188+
input_shape.append(conditions_shape)
179189

180-
self.subnet.build(input_shape)
190+
self.subnet.build(input_shape) # Pass list of shapes
191+
out_shape = self.subnet.compute_output_shape(input_shape)
181192

182-
input_shape = self.subnet.compute_output_shape(input_shape)
183-
self.output_projector.build(input_shape)
193+
self.output_projector.build(out_shape)
184194

185195
# Choose coefficient according to [2] Section 3.3
186196
self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -140,11 +140,23 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
140140
input_shape[-1] += 1
141141
if conditions_shape is not None:
142142
input_shape[-1] += conditions_shape[-1]
143+
input_shape = tuple(input_shape)
143144

144-
input_shape = tuple(input_shape)
145+
self.subnet.build(input_shape)
146+
out_shape = self.subnet.compute_output_shape(input_shape)
147+
else:
148+
# Multiple separate inputs
149+
main_input_shape = xz_shape
150+
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
151+
152+
# Build subnet with multiple input shapes
153+
input_shape = [main_input_shape, time_shape]
154+
if conditions_shape is not None:
155+
input_shape.append(conditions_shape)
156+
157+
self.subnet.build(input_shape) # Pass list of shapes
158+
out_shape = self.subnet.compute_output_shape(input_shape)
145159

146-
self.subnet.build(input_shape)
147-
out_shape = self.subnet.compute_output_shape(input_shape)
148160
self.output_projector.build(out_shape)
149161

150162
def get_config(self):

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -130,15 +130,28 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
130130

131131
input_shape = list(xz_shape)
132132
if self._concatenate_subnet_input:
133-
# account for concatenating the time and conditions
133+
# construct time vector
134134
input_shape[-1] += 1
135135
if conditions_shape is not None:
136136
input_shape[-1] += conditions_shape[-1]
137-
input_shape = tuple(input_shape)
137+
input_shape = tuple(input_shape)
138138

139-
self.subnet.build(input_shape)
140-
input_shape = self.subnet.compute_output_shape(input_shape)
141-
self.output_projector.build(input_shape)
139+
self.subnet.build(input_shape)
140+
out_shape = self.subnet.compute_output_shape(input_shape)
141+
else:
142+
# Multiple separate inputs
143+
main_input_shape = xz_shape
144+
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature
145+
146+
# Build subnet with multiple input shapes
147+
input_shape = [main_input_shape, time_shape]
148+
if conditions_shape is not None:
149+
input_shape.append(conditions_shape)
150+
151+
self.subnet.build(input_shape) # Pass list of shapes
152+
out_shape = self.subnet.compute_output_shape(input_shape)
153+
154+
self.output_projector.build(out_shape)
142155

143156
@classmethod
144157
def from_config(cls, config, custom_objects=None):

tests/test_networks/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,13 @@ def call(self, x, t, conditions=None, training=False):
2727
con = concatenate_valid([x, t, conditions], axis=-1)
2828
return self.mlp(con)
2929

30+
def compute_output_shape(self, input_shape):
31+
concatenate_input_shapes = tuple(
32+
(input_shape[0][0], sum([shape[-1] for shape in input_shape if shape is not None]))
33+
)
34+
out = self.mlp.compute_output_shape(concatenate_input_shapes)
35+
return out
36+
3037

3138
@pytest.fixture()
3239
def diffusion_model_edm_F_subnet_concatenate():
@@ -198,9 +205,12 @@ def typical_point_inference_network_subnet():
198205
"affine_coupling_flow",
199206
"spline_coupling_flow",
200207
"flow_matching",
208+
pytest.param("flow_matching_subnet_concatenate"),
201209
"free_form_flow",
202210
"consistency_model",
211+
pytest.param("consistency_model_subnet_concatenate"),
203212
pytest.param("diffusion_model_edm_F"),
213+
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
204214
pytest.param("diffusion_model_edm_noise", marks=pytest.mark.slow),
205215
pytest.param("diffusion_model_cosine_velocity", marks=pytest.mark.slow),
206216
pytest.param("diffusion_model_cosine_F", marks=pytest.mark.slow),

0 commit comments

Comments
 (0)