Skip to content

Commit 1bd5d46

Browse files
committed
Loading the state dict in a different way
1 parent 63a632e commit 1bd5d46

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

src/midst_toolkit/models/clavaddpm/model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def clava_clustering(tables, relation_order, save_dir, configs):
283283
return tables, all_group_lengths_prob_dicts
284284

285285

286-
def clava_training(tables, relation_order, save_dir, configs, device="cuda", initial_state_dict=None):
286+
def clava_training(tables, relation_order, save_dir, configs, device="cuda", initial_state_file_path=None):
287287
models = {}
288288
for parent, child in relation_order:
289289
print(f"Training {parent} -> {child} model from scratch")
@@ -298,7 +298,7 @@ def clava_training(tables, relation_order, save_dir, configs, device="cuda", ini
298298
child,
299299
configs,
300300
device,
301-
initial_state_dict,
301+
initial_state_file_path,
302302
)
303303

304304
models[(parent, child)] = result
@@ -324,7 +324,7 @@ def child_training(
324324
child_name: str,
325325
configs: dict[str, Any],
326326
device: str = "cuda",
327-
initial_state_dict: dict[str, Tensor] | None = None,
327+
initial_state_file_path: Path | None = None,
328328
) -> dict[str, Any]:
329329
if parent_name is None:
330330
y_col = "placeholder"
@@ -354,7 +354,7 @@ def child_training(
354354
configs["diffusion"]["lr"],
355355
configs["diffusion"]["weight_decay"],
356356
device=device,
357-
initial_state_dict=initial_state_dict,
357+
initial_state_file_path=initial_state_file_path,
358358
)
359359

360360
if parent_name is None:
@@ -398,7 +398,7 @@ def train_model(
398398
lr: float,
399399
weight_decay: float,
400400
device: str = "cuda",
401-
initial_state_dict: dict[str, Tensor] | None = None,
401+
initial_state_file_path: Path | None = None,
402402
) -> dict[str, Any]:
403403
T = Transformations(**T_dict)
404404
dataset, label_encoders, column_orders = make_dataset_from_df(
@@ -443,8 +443,14 @@ def train_model(
443443
)
444444
diffusion.to(device)
445445

446-
if initial_state_dict is not None:
447-
diffusion.load_state_dict(initial_state_dict)
446+
print("++++++++++++++++++++++++ BEFORE ++++++++++++++++++++++++++")
447+
print(diffusion.state_dict())
448+
449+
if initial_state_file_path is not None:
450+
diffusion.load_state_dict(torch.load(initial_state_file_path, weights_only=True))
451+
452+
print("++++++++++++++++++++++++ AFTER ++++++++++++++++++++++++++")
453+
print(diffusion.state_dict())
448454

449455
diffusion.train()
450456

16.4 MB
Binary file not shown.

tests/integration/models/clavaddpm/test_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,16 @@ def test_train_single_table(tmp_path: Path):
248248

249249
os.makedirs(tmp_path / "models")
250250
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG}
251-
initial_state_dict = pickle.loads(Path("tests/integration/data/diffusion_initial_state.pkl").read_bytes())
252251

253252
# Act
254253
tables, relation_order, _ = load_multi_table("tests/integration/data/single_table/")
255254
tables, models = clava_training(
256-
tables, relation_order, tmp_path, configs, device="cpu", initial_state_dict=initial_state_dict
255+
tables,
256+
relation_order,
257+
tmp_path,
258+
configs,
259+
device="cpu",
260+
initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
257261
)
258262

259263
# Assert
@@ -284,7 +288,7 @@ def test_train_single_table(tmp_path: Path):
284288
# if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
285289
# if the first layer is equal with minimal tolerance, all others should be equal as well
286290
assert all(
287-
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
291+
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.05)
288292
for layer in model_layers
289293
)
290294

@@ -311,12 +315,16 @@ def test_train_multi_table(tmp_path: Path):
311315
# Act
312316
os.makedirs(tmp_path / "models")
313317
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG}
314-
initial_state_dict = pickle.loads(Path("tests/integration/data/diffusion_initial_state.pkl").read_bytes())
315318

316319
tables, relation_order, _ = load_multi_table("tests/integration/data/multi_table/")
317320
tables, _ = clava_clustering(tables, relation_order, tmp_path, configs)
318321
models = clava_training(
319-
tables, relation_order, tmp_path, configs, device="cpu", initial_state_dict=initial_state_dict
322+
tables,
323+
relation_order,
324+
tmp_path,
325+
configs,
326+
device="cpu",
327+
initial_state_file_path="tests/integration/data/diffusion_initial_state.pth",
320328
)
321329

322330
# Assert
@@ -348,7 +356,7 @@ def test_train_multi_table(tmp_path: Path):
348356
# if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
349357
# if the first layer is equal with minimal tolerance, all others should be equal as well
350358
assert all(
351-
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.1)
359+
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.05)
352360
for layer in model_layers
353361
)
354362

0 commit comments

Comments
 (0)