@@ -267,7 +267,7 @@ def test_train_single_table(tmp_path: Path):
267267 )
268268 X_gen , y_gen = x_gen_tensor .numpy (), y_gen_tensor .numpy ()
269269
270- with open ("tests/integration/data/single_table/assertion_data/syntetic_trans_data .json" , "r" ) as f :
270+ with open ("tests/integration/data/single_table/assertion_data/syntetic_data .json" , "r" ) as f :
271271 expected_results = json .load (f )
272272
273273 # Assert the synthetic samples are within the expected values
@@ -281,14 +281,41 @@ def test_train_single_table(tmp_path: Path):
281281
@pytest.mark.integration_test()
def test_train_multi_table(tmp_path: Path) -> None:
    """Train the multi-table pipeline on CPU and check samples from the "trans" model.

    The test fixes all random seeds, runs clustering and training on the data
    under ``tests/integration/data/multi_table/``, reads ``table_info`` back
    from the checkpoint written to ``tmp_path / "models"``, draws a few
    samples from the ``(None, "trans")`` diffusion model and compares them
    with the stored expected results.

    Args:
        tmp_path: Temporary directory (pytest fixture) that receives the
            ``models`` folder and the training artifacts.
    """
    # Setup
    set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)

    # Everything after seeding runs inside try/finally so the global seed
    # state is always restored, even when training or an assertion fails;
    # otherwise the deterministic settings would leak into later tests.
    try:
        # Act
        os.makedirs(tmp_path / "models")
        configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG}

        tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
        tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs)
        models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")

        # Assert
        # The checkpoint is produced by this very test run, so unpickling it is safe.
        with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
            table_info = pickle.load(f)["table_info"]

        sample_size = 5
        key = (None, "trans")
        x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
            sample_size,
            DIFFUSION_CONFIG["batch_size"],
            table_info[key]["empirical_class_dist"].float(),
            ddim=False,
        )
        X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()

        with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
            expected_results = json.load(f)

        # Assert the synthetic samples are within the expected values
        # For X_gen, we are checking if the standard deviations of each row
        # are within a pre-defined range with some percentage of tolerance
        assert np.allclose(X_gen.std(axis=1, ddof=0), expected_results["X_gen_std"], rtol=0.05, atol=0)
        assert all(y_gen == expected_results["y_gen"])
    finally:
        unset_all_random_seeds()
292319
293320
294321@pytest .mark .integration_test ()
@@ -308,11 +335,6 @@ def test_clustering_reload(tmp_path: Path):
308335 account_original_df_as_float = tables ["account" ]["original_df" ].astype (float )
309336 assert account_df_no_clustering .equals (account_original_df_as_float )
310337
311- print (tables ["account" ]["df" ]["account_trans_cluster" ].tolist ())
312- print (tables ["trans" ]["df" ]["account_trans_cluster" ].tolist ())
313-
314- assert account_df_no_clustering == account_original_df_as_float
315-
316338 if _is_apple_silicon ():
317339 # TODO: Figure out if there is a good way of testing the clustering results
318340 # on multiple platforms. https://app.clickup.com/t/868f43wp0
0 commit comments