Skip to content

Commit 0012dcb

Browse files
committed
Fix type checks in compare state dicts
1 parent e8bd2d7 commit 0012dcb

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tests/tests_pytorch/loops/test_loops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515
from copy import deepcopy
1616
from dataclasses import dataclass
17-
from typing import Dict, Iterator, Any
17+
from typing import Any, Dict, Iterator
1818
from unittest.mock import ANY, Mock
1919

2020
import pytest
@@ -575,7 +575,7 @@ def test_fit_loop_reset(tmp_path):
575575

576576
fit_loop.reset()
577577
epoch_loop.reset()
578-
578+
579579
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
580580
assert fit_loop.restarting
581581
assert fit_loop.epoch_progress.total.ready == 1
@@ -629,27 +629,28 @@ def compare_state_dicts(dict1, dict2):
629629
def compare_leaves(d1, d2):
630630
result = {}
631631
all_keys = set(d1.keys()).union(d2.keys())
632-
632+
633633
for key in all_keys:
634634
val1 = d1.get(key, None)
635635
val2 = d2.get(key, None)
636-
636+
637637
if isinstance(val1, dict) and isinstance(val2, dict):
638638
res = compare_leaves(val1, val2)
639639
if res:
640640
result[key] = res
641641
elif isinstance(val1, dict) or isinstance(val2, dict):
642642
raise ValueError("dicts have different leaves")
643-
elif type(val1) == torch.Tensor and type(val2) == torch.Tensor:
643+
elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
644644
diff = torch.norm(val1 - val2)
645645
if diff > 1e-8:
646646
result[key] = f"{diff} > 1e-8"
647-
elif type(val1) == float and type(val2) == float:
647+
elif isinstance(val1, float) and isinstance(val2, float):
648648
if abs(val1 - val2) > 1e-8:
649649
result[key] = f"{val1} != {val2}"
650650
elif val1 != val2:
651651
result[key] = f"{val1} != {val2}"
652652
return result
653+
653654
return compare_leaves(dict1, dict2)
654655

655656

@@ -718,7 +719,7 @@ def test_restart_parity(tmp_path):
718719
trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
719720
loss_v1 = model.last_loss
720721

721-
assert(abs(loss - loss_v1) < 1e-8)
722+
assert abs(loss - loss_v1) < 1e-8
722723

723724
end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
724725
end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True)
@@ -783,7 +784,7 @@ def test_restart_parity_with_val(tmp_path):
783784
trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
784785
loss_v1 = model.last_loss
785786

786-
assert(abs(loss - loss_v1) < 1e-8)
787+
assert abs(loss - loss_v1) < 1e-8
787788

788789
end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
789790
end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True)

0 commit comments

Comments
 (0)