Skip to content

Commit 690d33b

Browse files
committed
fix ConcatenateMLP
1 parent 2efed33 commit 690d33b

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/test_networks/conftest.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,21 @@ def compute_output_shape(self, input_shape_x, input_shape_t, input_shape_conditi
3737
+ (input_shape_conditions[-1] if input_shape_conditions is not None else 0),
3838
)
3939
)
40-
out = self.mlp.compute_output_shape(concatenate_input_shapes)
41-
return out
40+
return self.mlp.compute_output_shape(concatenate_input_shapes)
4241

4342
def build(self, input_shape_x, input_shape_t, input_shape_conditions=None):
4443
if self.built:
4544
return
4645

47-
input_shape = self.compute_output_shape(input_shape_x, input_shape_t, input_shape_conditions)
48-
self.mlp.build(input_shape)
46+
concatenate_input_shapes = tuple(
47+
(
48+
input_shape_x[0],
49+
input_shape_x[-1]
50+
+ input_shape_t[-1]
51+
+ (input_shape_conditions[-1] if input_shape_conditions is not None else 0),
52+
)
53+
)
54+
self.mlp.build(concatenate_input_shapes)
4955

5056

5157
@pytest.fixture()

0 commit comments

Comments
 (0)