Skip to content

Commit 378efef

Browse files
committed
Removing debug asserts, adding asserts for multi table generation
1 parent cc0ad5c commit 378efef

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

tests/integration/data/single_table/assertion_data/syntetic_trans_data.json renamed to tests/integration/data/multi_table/assertion_data/syntetic_data.json

File renamed without changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"X_gen_std": [1270.027488, 497.444234, 1669.368549, 2383.459776, 1170.001261],
3+
"y_gen": [9, 47, 15, 6, 90]
4+
}

tests/integration/models/clavaddpm/test_model.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def test_train_single_table(tmp_path: Path):
267267
)
268268
X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
269269

270-
with open("tests/integration/data/single_table/assertion_data/syntetic_trans_data.json", "r") as f:
270+
with open("tests/integration/data/single_table/assertion_data/syntetic_data.json", "r") as f:
271271
expected_results = json.load(f)
272272

273273
# Assert the synthetic samples are within the expected values
@@ -281,14 +281,41 @@ def test_train_single_table(tmp_path: Path):
281281

282282
@pytest.mark.integration_test()
283283
def test_train_multi_table(tmp_path: Path):
284+
# Setup
285+
set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)
286+
287+
# Act
284288
os.makedirs(tmp_path / "models")
285289
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG}
286290

287291
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
288292
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs)
289293
models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
290294

291-
assert models
295+
# Assert
296+
with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
297+
table_info = pickle.load(f)["table_info"]
298+
299+
sample_size = 5
300+
key = (None, "trans")
301+
x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
302+
sample_size,
303+
DIFFUSION_CONFIG["batch_size"],
304+
table_info[key]["empirical_class_dist"].float(),
305+
ddim=False,
306+
)
307+
X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
308+
309+
with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
310+
expected_results = json.load(f)
311+
312+
# Assert the synthetic samples are within the expected values
313+
# For X_gen, we are checking if the standard deviations of each row
314+
# are within a pre-defined range with some percentage of tolerance
315+
assert np.allclose(X_gen.std(axis=1, ddof=0), expected_results["X_gen_std"], rtol=0.05, atol=0)
316+
assert all(y_gen == expected_results["y_gen"])
317+
318+
unset_all_random_seeds()
292319

293320

294321
@pytest.mark.integration_test()
@@ -308,11 +335,6 @@ def test_clustering_reload(tmp_path: Path):
308335
account_original_df_as_float = tables["account"]["original_df"].astype(float)
309336
assert account_df_no_clustering.equals(account_original_df_as_float)
310337

311-
print(tables["account"]["df"]["account_trans_cluster"].tolist())
312-
print(tables["trans"]["df"]["account_trans_cluster"].tolist())
313-
314-
assert account_df_no_clustering == account_original_df_as_float
315-
316338
if _is_apple_silicon():
317339
# TODO: Figure out if there is a good way of testing the clustering results
318340
# on multiple platforms. https://app.clickup.com/t/868f43wp0

0 commit comments

Comments
 (0)