@@ -251,31 +251,24 @@ def test_train_single_table(tmp_path: Path):
 
     # Act
     tables, relation_order, _ = load_multi_table("tests/integration/data/single_table/")
-    tables, models = clava_training(
-        tables,
-        relation_order,
-        tmp_path,
-        configs,
-        device="cpu",
-        initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
-    )
+    tables, models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
 
     # Assert
-    # with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
-    #     table_info = pickle.load(f)["table_info"]
+    with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
+        table_info = pickle.load(f)["table_info"]
 
-    # sample_size = 5
+    sample_size = 5
     key = (None, "trans")
-    # x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
-    #     sample_size,
-    #     DIFFUSION_CONFIG["batch_size"],
-    #     table_info[key]["empirical_class_dist"].float(),
-    #     ddim=False,
-    # )
-    # X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
+    x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
+        sample_size,
+        DIFFUSION_CONFIG["batch_size"],
+        table_info[key]["empirical_class_dist"].float(),
+        ddim=False,
+    )
+    X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
 
-    # with open("tests/integration/data/single_table/assertion_data/syntetic_data.json", "r") as f:
-    #     expected_results = json.load(f)
+    with open("tests/integration/data/single_table/assertion_data/syntetic_data.json", "r") as f:
+        expected_results = json.load(f)
 
     model_data = dict(models[key]["diffusion"].named_parameters())
 
@@ -284,25 +277,25 @@ def test_train_single_table(tmp_path: Path):
     )
 
     model_layers = list(model_data.keys())
-    # expected_model_layers = list(expected_model_data.keys())
-    # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
-    # if the first layer is equal with minimal tolerance, all others should be equal as well
-    assert all(
-        np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.08)
-        for layer in model_layers
-    )
-
-    # TODO: Figure out if there is a good way of testing the synthetic data results
-    # on multiple platforms. https://app.clickup.com/t/868f43wp0
-    # assert np.allclose(X_gen, expected_results["X_gen"])
-    # assert np.allclose(y_gen, expected_results["y_gen"])
+    expected_model_layers = list(expected_model_data.keys())
+    if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
+        # if the first layer is equal with minimal tolerance, all others should be equal as well
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers
+        )
+
+        # TODO: Figure out if there is a good way of testing the synthetic data results
+        # on multiple platforms. https://app.clickup.com/t/868f43wp0
+        assert np.allclose(X_gen, expected_results["X_gen"])
+        assert np.allclose(y_gen, expected_results["y_gen"])
 
-    # else:
-    #     # Otherwise, set a tolerance that would work across platforms
-    #     assert all(
-    #         np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
-    #         for layer in model_layers
-    #     )
+    else:
+        # Otherwise, set a tolerance that would work across platforms
+        # TODO: Figure out a way to set a lower tolerance
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
+            for layer in model_layers
+        )
 
     unset_all_random_seeds()
 
@@ -318,31 +311,24 @@ def test_train_multi_table(tmp_path: Path):
 
     tables, relation_order, _ = load_multi_table("tests/integration/data/multi_table/")
     tables, _ = clava_clustering(tables, relation_order, tmp_path, configs)
-    models = clava_training(
-        tables,
-        relation_order,
-        tmp_path,
-        configs,
-        device="cpu",
-        initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
-    )
+    models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
 
     # Assert
-    # with open(tmp_path / "models" / "account_trans_ckpt.pkl", "rb") as f:
-    #     table_info = pickle.load(f)["table_info"]
+    with open(tmp_path / "models" / "account_trans_ckpt.pkl", "rb") as f:
+        table_info = pickle.load(f)["table_info"]
 
-    # sample_size = 5
+    sample_size = 5
     key = ("account", "trans")
-    # x_gen_tensor, y_gen_tensor = models[1][key]["diffusion"].sample_all(
-    #     sample_size,
-    #     DIFFUSION_CONFIG["batch_size"],
-    #     table_info[key]["empirical_class_dist"].float(),
-    #     ddim=False,
-    # )
-    # X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
+    x_gen_tensor, y_gen_tensor = models[1][key]["diffusion"].sample_all(
+        sample_size,
+        DIFFUSION_CONFIG["batch_size"],
+        table_info[key]["empirical_class_dist"].float(),
+        ddim=False,
+    )
+    X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
 
-    # with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
-    #     expected_results = json.load(f)
+    with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
+        expected_results = json.load(f)
 
     model_data = dict(models[1][key]["diffusion"].named_parameters())
 
@@ -351,26 +337,26 @@ def test_train_multi_table(tmp_path: Path):
     )
 
     model_layers = list(model_data.keys())
-    # expected_model_layers = list(expected_model_data.keys())
+    expected_model_layers = list(expected_model_data.keys())
 
-    # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
-    # if the first layer is equal with minimal tolerance, all others should be equal as well
-    assert all(
-        np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.08)
-        for layer in model_layers
-    )
+    if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
+        # if the first layer is equal with minimal tolerance, all others should be equal as well
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers
+        )
 
-    # # TODO: Figure out if there is a good way of testing the synthetic data results
-    # # on multiple platforms. https://app.clickup.com/t/868f43wp0
-    # assert np.allclose(X_gen, expected_results["X_gen"])
-    # assert np.allclose(y_gen, expected_results["y_gen"])
-
-    # else:
-    #     # Otherwise, set a tolerance that would work across platforms
-    #     assert all(
-    #         np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
-    #         for layer in model_layers
-    #     )
+        # TODO: Figure out if there is a good way of testing the synthetic data results
+        # on multiple platforms. https://app.clickup.com/t/868f43wp0
+        assert np.allclose(X_gen, expected_results["X_gen"])
+        assert np.allclose(y_gen, expected_results["y_gen"])
+
+    else:
+        # Otherwise, set a tolerance that would work across platforms
+        # TODO: Figure out a way to set a lower tolerance
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
+            for layer in model_layers
+        )
 
     unset_all_random_seeds()
 
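Both tests now share the same two-tier comparison: if the first layer of the trained model matches the stored reference at `np.allclose`'s default tolerance, every layer and the sampled synthetic data are held to that tight tolerance; otherwise a looser `atol=0.1` absorbs platform-specific numeric differences. A minimal sketch of that pattern as a standalone helper follows; the function name and tolerance parameters are illustrative, not part of the test suite:

```python
import numpy as np
import torch


def params_match(
    actual: dict[str, torch.Tensor],
    expected: dict[str, torch.Tensor],
    loose_atol: float = 0.1,
) -> bool:
    """Two-tier parameter comparison, mirroring the tests above.

    If the first layer matches at np.allclose's default tolerance, the run
    reproduced the reference closely, so hold every layer to that tolerance.
    Otherwise fall back to `loose_atol` to absorb platform differences.
    """
    layers = list(actual.keys())
    # Probe the first layer to decide which tolerance to apply.
    tight = np.allclose(actual[layers[0]].detach(), expected[layers[0]].detach())
    atol = 1e-8 if tight else loose_atol  # 1e-8 is np.allclose's default atol
    return all(
        np.allclose(actual[layer].detach(), expected[layer].detach(), atol=atol)
        for layer in layers
    )
```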