@@ -10,7 +10,7 @@ def test_forward_standardization_training():
     layer = Standardization(momentum=0.0)  # no EMA for test stability
     layer.build(random_input.shape)
 
-    out = layer(random_input, stage="training", forward=True)
+    out = layer(random_input, stage="training")
 
     moving_mean = keras.ops.convert_to_numpy(layer.moving_mean)
     moving_std = keras.ops.convert_to_numpy(layer.moving_std)
@@ -32,7 +32,7 @@ def test_inverse_standardization_ldj():
     layer.build(random_input.shape)
 
     _ = layer(random_input, stage="training", forward=True)  # trigger moment update
-    inv_x, ldj = layer(random_input, stage="inference", forward=False)
+    inv_x, ldj = layer(random_input, stage="inference", forward=False, log_det_jac=True)
 
     assert inv_x.shape == random_input.shape
     assert ldj.shape == random_input.shape[:-1]
@@ -43,7 +43,7 @@ def test_consistency_forward_inverse():
     layer = Standardization(momentum=0.0)
     layer.build((5,))
     standardized = layer(random_input, stage="training", forward=True)
-    recovered, _ = layer(standardized, stage="inference", forward=False)
+    recovered = layer(standardized, stage="inference", forward=False)
 
     random_input = keras.ops.convert_to_numpy(random_input)
     recovered = keras.ops.convert_to_numpy(recovered)