Skip to content

Commit 45ab9ea

Browse files
committed
Fix tests
1 parent 7aeb9cb commit 45ab9ea

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_networks/test_standardization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_forward_standardization_training():
1010
layer = Standardization(momentum=0.0) # no EMA for test stability
1111
layer.build(random_input.shape)
1212

13-
out = layer(random_input, stage="training", forward=True)
13+
out = layer(random_input, stage="training")
1414

1515
moving_mean = keras.ops.convert_to_numpy(layer.moving_mean)
1616
moving_std = keras.ops.convert_to_numpy(layer.moving_std)
@@ -32,7 +32,7 @@ def test_inverse_standardization_ldj():
3232
layer.build(random_input.shape)
3333

3434
_ = layer(random_input, stage="training", forward=True) # trigger moment update
35-
inv_x, ldj = layer(random_input, stage="inference", forward=False)
35+
inv_x, ldj = layer(random_input, stage="inference", forward=False, log_det_jac=True)
3636

3737
assert inv_x.shape == random_input.shape
3838
assert ldj.shape == random_input.shape[:-1]
@@ -43,7 +43,7 @@ def test_consistency_forward_inverse():
4343
layer = Standardization(momentum=0.0)
4444
layer.build((5,))
4545
standardized = layer(random_input, stage="training", forward=True)
46-
recovered, _ = layer(standardized, stage="inference", forward=False)
46+
recovered = layer(standardized, stage="inference", forward=False)
4747

4848
random_input = keras.ops.convert_to_numpy(random_input)
4949
recovered = keras.ops.convert_to_numpy(recovered)

0 commit comments

Comments
 (0)