
Commit 8c4792d: "reverting"
1 parent: 636b271

File tree: 3 files changed (+63, -84 lines)


src/midst_toolkit/models/clavaddpm/model.py

Lines changed: 1 addition & 8 deletions
@@ -283,7 +283,7 @@ def clava_clustering(tables, relation_order, save_dir, configs):
     return tables, all_group_lengths_prob_dicts
 
 
-def clava_training(tables, relation_order, save_dir, configs, device="cuda", initial_state_file_path=None):
+def clava_training(tables, relation_order, save_dir, configs, device="cuda"):
     models = {}
     for parent, child in relation_order:
         print(f"Training {parent} -> {child} model from scratch")
@@ -298,7 +298,6 @@ def clava_training(tables, relation_order, save_dir, configs, device="cuda", ini
             child,
            configs,
             device,
-            initial_state_file_path,
         )
 
         models[(parent, child)] = result
@@ -354,7 +353,6 @@ def child_training(
         configs["diffusion"]["lr"],
         configs["diffusion"]["weight_decay"],
         device=device,
-        initial_state_file_path=initial_state_file_path,
     )
 
     if parent_name is None:
@@ -398,7 +396,6 @@ def train_model(
     lr: float,
     weight_decay: float,
     device: str = "cuda",
-    initial_state_file_path: Path | None = None,
 ) -> dict[str, Any]:
     T = Transformations(**T_dict)
     dataset, label_encoders, column_orders = make_dataset_from_df(
@@ -442,10 +439,6 @@ def train_model(
         device=torch.device(device),
     )
     diffusion.to(device)
-
-    if initial_state_file_path is not None:
-        diffusion.load_state_dict(torch.load(initial_state_file_path, weights_only=True))
-
     diffusion.train()
 
     trainer = Trainer(
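For context on what this revert removes: train_model previously accepted an optional initial_state_file_path and, when it was set, loaded those saved weights into the diffusion model before training. Below is a minimal sketch of that warm-start pattern; the helper name is hypothetical (the repo inlined this logic in train_model rather than using a function like this):

```python
from pathlib import Path

import torch


def warm_start(model: torch.nn.Module, initial_state_file_path: Path | None) -> None:
    # Pattern removed by this revert: optionally load a saved state dict
    # before training. weights_only=True restricts torch.load to plain
    # tensor data instead of arbitrary pickled objects.
    if initial_state_file_path is not None:
        model.load_state_dict(torch.load(initial_state_file_path, weights_only=True))
    model.train()
```

With the revert in place, every (parent, child) model is trained from scratch, which is what the restored tests below exercise.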
Binary file not shown (-16.4 MB).

tests/integration/models/clavaddpm/test_model.py

Lines changed: 62 additions & 76 deletions
@@ -251,31 +251,24 @@ def test_train_single_table(tmp_path: Path):
 
     # Act
     tables, relation_order, _ = load_multi_table("tests/integration/data/single_table/")
-    tables, models = clava_training(
-        tables,
-        relation_order,
-        tmp_path,
-        configs,
-        device="cpu",
-        initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
-    )
+    tables, models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
 
     # Assert
-    # with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
-    #     table_info = pickle.load(f)["table_info"]
+    with open(tmp_path / "models" / "None_trans_ckpt.pkl", "rb") as f:
+        table_info = pickle.load(f)["table_info"]
 
-    # sample_size = 5
+    sample_size = 5
     key = (None, "trans")
-    # x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
-    #     sample_size,
-    #     DIFFUSION_CONFIG["batch_size"],
-    #     table_info[key]["empirical_class_dist"].float(),
-    #     ddim=False,
-    # )
-    # X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
+    x_gen_tensor, y_gen_tensor = models[key]["diffusion"].sample_all(
+        sample_size,
+        DIFFUSION_CONFIG["batch_size"],
+        table_info[key]["empirical_class_dist"].float(),
+        ddim=False,
+    )
+    X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
 
-    # with open("tests/integration/data/single_table/assertion_data/syntetic_data.json", "r") as f:
-    #     expected_results = json.load(f)
+    with open("tests/integration/data/single_table/assertion_data/syntetic_data.json", "r") as f:
+        expected_results = json.load(f)
 
     model_data = dict(models[key]["diffusion"].named_parameters())
 
@@ -284,25 +277,25 @@ def test_train_single_table(tmp_path: Path):
     )
 
     model_layers = list(model_data.keys())
-    # expected_model_layers = list(expected_model_data.keys())
-    # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
-    # if the first layer is equal with minimal tolerance, all others should be equal as well
-    assert all(
-        np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.08)
-        for layer in model_layers
-    )
-
-    # TODO: Figure out if there is a good way of testing the synthetic data results
-    # on multiple platforms. https://app.clickup.com/t/868f43wp0
-    # assert np.allclose(X_gen, expected_results["X_gen"])
-    # assert np.allclose(y_gen, expected_results["y_gen"])
+    expected_model_layers = list(expected_model_data.keys())
+    if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
+        # if the first layer is equal with minimal tolerance, all others should be equal as well
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers
+        )
+
+        # TODO: Figure out if there is a good way of testing the synthetic data results
+        # on multiple platforms. https://app.clickup.com/t/868f43wp0
+        assert np.allclose(X_gen, expected_results["X_gen"])
+        assert np.allclose(y_gen, expected_results["y_gen"])
 
-    # else:
-    #     # Otherwise, set a tolerance that would work across platforms
-    #     assert all(
-    #         np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
-    #         for layer in model_layers
-    #     )
+    else:
+        # Otherwise, set a tolerance that would work across platforms
+        # TODO: Figure out a way to set a lower tolerance
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
+            for layer in model_layers
+        )
 
     unset_all_random_seeds()
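The restored assertion is a two-tier weight comparison: if the first layer matches the reference checkpoint at np.allclose's default tolerance, every layer must match at that tolerance and the synthetic samples are checked against stored results; otherwise a looser atol=0.1 is accepted to absorb cross-platform numeric drift. A standalone sketch of that check follows; the helper name is illustrative, not part of the test file:

```python
import numpy as np
import torch


def assert_weights_match(
    model_data: dict[str, torch.Tensor],
    expected_model_data: dict[str, torch.Tensor],
) -> None:
    # Two-tier check mirroring the restored tests: tight tolerance when the
    # first layer already agrees, looser cross-platform tolerance otherwise.
    layers = list(model_data.keys())
    if np.allclose(model_data[layers[0]].detach(), expected_model_data[layers[0]].detach()):
        # If the first layer is equal at minimal tolerance, all others
        # should be equal as well.
        assert all(
            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach())
            for layer in layers
        )
    else:
        # Otherwise accept a tolerance that holds across platforms.
        assert all(
            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
            for layer in layers
        )
```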

@@ -318,31 +311,24 @@ def test_train_multi_table(tmp_path: Path):
 
     tables, relation_order, _ = load_multi_table("tests/integration/data/multi_table/")
     tables, _ = clava_clustering(tables, relation_order, tmp_path, configs)
-    models = clava_training(
-        tables,
-        relation_order,
-        tmp_path,
-        configs,
-        device="cpu",
-        initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
-    )
+    models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
 
     # Assert
-    # with open(tmp_path / "models" / "account_trans_ckpt.pkl", "rb") as f:
-    #     table_info = pickle.load(f)["table_info"]
+    with open(tmp_path / "models" / "account_trans_ckpt.pkl", "rb") as f:
+        table_info = pickle.load(f)["table_info"]
 
-    # sample_size = 5
+    sample_size = 5
     key = ("account", "trans")
-    # x_gen_tensor, y_gen_tensor = models[1][key]["diffusion"].sample_all(
-    #     sample_size,
-    #     DIFFUSION_CONFIG["batch_size"],
-    #     table_info[key]["empirical_class_dist"].float(),
-    #     ddim=False,
-    # )
-    # X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
+    x_gen_tensor, y_gen_tensor = models[1][key]["diffusion"].sample_all(
+        sample_size,
+        DIFFUSION_CONFIG["batch_size"],
+        table_info[key]["empirical_class_dist"].float(),
+        ddim=False,
+    )
+    X_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()
 
-    # with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
-    #     expected_results = json.load(f)
+    with open("tests/integration/data/multi_table/assertion_data/syntetic_data.json", "r") as f:
+        expected_results = json.load(f)
 
     model_data = dict(models[1][key]["diffusion"].named_parameters())
 
@@ -351,26 +337,26 @@ def test_train_multi_table(tmp_path: Path):
     )
 
     model_layers = list(model_data.keys())
-    # expected_model_layers = list(expected_model_data.keys())
+    expected_model_layers = list(expected_model_data.keys())
 
-    # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
-    # if the first layer is equal with minimal tolerance, all others should be equal as well
-    assert all(
-        np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.08)
-        for layer in model_layers
-    )
+    if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
+        # if the first layer is equal with minimal tolerance, all others should be equal as well
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers
+        )
 
-    # # TODO: Figure out if there is a good way of testing the synthetic data results
-    # # on multiple platforms. https://app.clickup.com/t/868f43wp0
-    # assert np.allclose(X_gen, expected_results["X_gen"])
-    # assert np.allclose(y_gen, expected_results["y_gen"])
-
-    # else:
-    #     # Otherwise, set a tolerance that would work across platforms
-    #     assert all(
-    #         np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
-    #         for layer in model_layers
-    #     )
+        # TODO: Figure out if there is a good way of testing the synthetic data results
+        # on multiple platforms. https://app.clickup.com/t/868f43wp0
+        assert np.allclose(X_gen, expected_results["X_gen"])
+        assert np.allclose(y_gen, expected_results["y_gen"])
+
+    else:
+        # Otherwise, set a tolerance that would work across platforms
+        # TODO: Figure out a way to set a lower tolerance
+        assert all(
+            np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
+            for layer in model_layers
+        )
 
     unset_all_random_seeds()
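The restored assertions also read table metadata back from the checkpoint that clava_training writes under the models subdirectory of the save directory. A small sketch of that readback, assuming such a checkpoint already exists at the path the multi-table test uses:

```python
import pickle
from pathlib import Path

# Checkpoint written during clava_training; the file name encodes the
# (parent, child) pair, here ("account", "trans") as in the test above.
ckpt_path = Path("models") / "account_trans_ckpt.pkl"

with open(ckpt_path, "rb") as f:
    table_info = pickle.load(f)["table_info"]

# Per-relation metadata used to drive sampling in the assertions,
# e.g. the empirical class distribution passed to sample_all.
empirical_class_dist = table_info[("account", "trans")]["empirical_class_dist"]
```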
