Skip to content

Commit f819250

Browse files
committed
change build parameters to match Keras convention
The convention is to use parameter name with a `_shape` suffix.
1 parent bece7cf commit f819250

File tree

4 files changed

+13
-17
lines changed

4 files changed

+13
-17
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ def build(self, xz_shape, conditions_shape=None):
180180
else:
181181
# Multiple separate inputs
182182
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
183-
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
183+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
184184
out_shape = self.subnet.compute_output_shape(
185-
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
185+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
186186
)
187187
self.output_projector.build(out_shape)
188188

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
147147
else:
148148
# Multiple separate inputs
149149
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
150-
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
150+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
151151
out_shape = self.subnet.compute_output_shape(
152-
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
152+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
153153
)
154154

155155
self.output_projector.build(out_shape)

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
141141
else:
142142
# Multiple separate inputs
143143
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
144-
self.subnet.build(input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape)
144+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
145145
out_shape = self.subnet.compute_output_shape(
146-
input_shape_x=xz_shape, input_shape_t=time_shape, input_shape_conditions=conditions_shape
146+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
147147
)
148148

149149
self.output_projector.build(out_shape)

tests/test_networks/conftest.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def diffusion_model_edm_F():
1818
)
1919

2020

21-
@serializable("bayesflow.networks")
21+
@serializable("test", disable_module_check=True)
2222
class ConcatenateMLP(Sequential):
2323
def __init__(
2424
self,
@@ -33,27 +33,23 @@ def call(self, x, t, conditions=None, training=False):
3333
con = concatenate_valid([x, t, conditions], axis=-1)
3434
return self.mlp(con)
3535

36-
def compute_output_shape(self, input_shape_x, input_shape_t, input_shape_conditions=None):
36+
def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
3737
concatenate_input_shapes = tuple(
3838
(
39-
input_shape_x[0],
40-
input_shape_x[-1]
41-
+ input_shape_t[-1]
42-
+ (input_shape_conditions[-1] if input_shape_conditions is not None else 0),
39+
x_shape[0],
40+
x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0),
4341
)
4442
)
4543
return self.mlp.compute_output_shape(concatenate_input_shapes)
4644

47-
def build(self, input_shape_x, input_shape_t, input_shape_conditions=None):
45+
def build(self, x_shape, t_shape, conditions_shape=None):
4846
if self.built:
4947
return
5048

5149
concatenate_input_shapes = tuple(
5250
(
53-
input_shape_x[0],
54-
input_shape_x[-1]
55-
+ input_shape_t[-1]
56-
+ (input_shape_conditions[-1] if input_shape_conditions is not None else 0),
51+
x_shape[0],
52+
x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0),
5753
)
5854
)
5955
self.mlp.build(concatenate_input_shapes)

0 commit comments

Comments
 (0)