@@ -327,9 +327,7 @@ def test_nnpe(random_data):
327327
328328 # Test basic case with global noise application
329329 ad = Adapter ().nnpe ("x1" , spike_scale = 1.0 , slab_scale = 1.0 , per_dimension = False , seed = 42 )
330- result_training = ad (random_data , stage = "training" )
331- result_validation = ad (random_data , stage = "validation" )
332- result_inference = ad (random_data , stage = "inference" )
330+ result_training = ad (random_data )
333331 result_inversed = ad (random_data , inverse = True )
334332 serialized = serialize (ad )
335333 deserialized = deserialize (serialized )
@@ -349,13 +347,11 @@ def test_nnpe(random_data):
349347
350348 # check that the validation and inference data as well as inversed results are unchanged
351349 for k , v in random_data .items ():
352- assert np .allclose (result_validation [k ], v )
353- assert np .allclose (result_inference [k ], v )
354350 assert np .allclose (result_inversed [k ], v )
355351
356352 # Test both scales and seed are None case (automatic scale determination) with dimensionwise noise application
357353 ad_auto = Adapter ().nnpe ("y1" , slab_scale = None , spike_scale = None , per_dimension = True , seed = None )
358- result_training_auto = ad_auto (random_data , stage = "training" )
354+ result_training_auto = ad_auto (random_data )
359355 assert not np .allclose (result_training_auto ["y1" ], random_data ["y1" ])
360356 for k , v in random_data .items ():
361357 if k == "y1" :
@@ -378,8 +374,8 @@ def test_nnpe(random_data):
378374 # Apply dimensionwise and global adapters with automatic slab_scale scale determination
379375 ad_partial_global = Adapter ().nnpe ("x" , spike_scale = 0 , slab_scale = None , per_dimension = False , seed = 42 )
380376 ad_partial_dim = Adapter ().nnpe ("x" , spike_scale = [0 , 1 ], slab_scale = None , per_dimension = True , seed = 42 )
381- res_dim = ad_partial_dim (var_data , stage = "training" )
382- res_glob = ad_partial_global (var_data , stage = "training" )
377+ res_dim = ad_partial_dim (var_data )
378+ res_glob = ad_partial_global (var_data )
383379
384380 # Compute standard deviations of noise per last axis dimension
385381 noise_dim = res_dim ["x" ] - var_data ["x" ]
0 commit comments