@@ -14,7 +14,7 @@
 import os
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Dict, Iterator, Any
+from typing import Any, Dict, Iterator
 from unittest.mock import ANY, Mock

 import pytest
@@ -575,7 +575,7 @@ def test_fit_loop_reset(tmp_path):

     fit_loop.reset()
     epoch_loop.reset()
-
+
     # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
     assert fit_loop.restarting
     assert fit_loop.epoch_progress.total.ready == 1
@@ -629,27 +629,28 @@ def compare_state_dicts(dict1, dict2):
     def compare_leaves(d1, d2):
         result = {}
         all_keys = set(d1.keys()).union(d2.keys())
-
+
         for key in all_keys:
             val1 = d1.get(key, None)
             val2 = d2.get(key, None)
-
+
             if isinstance(val1, dict) and isinstance(val2, dict):
                 res = compare_leaves(val1, val2)
                 if res:
                     result[key] = res
             elif isinstance(val1, dict) or isinstance(val2, dict):
                 raise ValueError("dicts have different leaves")
-            elif type(val1) == torch.Tensor and type(val2) == torch.Tensor:
+            elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
                 diff = torch.norm(val1 - val2)
                 if diff > 1e-8:
                     result[key] = f"{diff} > 1e-8"
-            elif type(val1) == float and type(val2) == float:
+            elif isinstance(val1, float) and isinstance(val2, float):
                 if abs(val1 - val2) > 1e-8:
                     result[key] = f"{val1} != {val2}"
             elif val1 != val2:
                 result[key] = f"{val1} != {val2}"
         return result
+
     return compare_leaves(dict1, dict2)


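A note on the isinstance change above: an exact type(...) comparison rejects torch.Tensor subclasses such as torch.nn.Parameter, so such a leaf would have skipped the tensor branch and fallen through to the generic val1 != val2 check, whose truthiness is ambiguous for multi-element tensors. A minimal standalone sketch of the difference (illustration only, not part of the diff):

    import torch

    p1 = torch.nn.Parameter(torch.ones(3))
    p2 = torch.nn.Parameter(torch.ones(3))

    # Parameter is a Tensor subclass: the exact-type check fails, isinstance does not.
    print(type(p1) == torch.Tensor)      # False
    print(isinstance(p1, torch.Tensor))  # True

    # With isinstance, Parameter leaves take the norm-based comparison used above.
    print(torch.norm(p1 - p2) <= 1e-8)   # tensor(True)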
@@ -718,7 +719,7 @@ def test_restart_parity(tmp_path):
     trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
     loss_v1 = model.last_loss

-    assert(abs(loss - loss_v1) < 1e-8)
+    assert abs(loss - loss_v1) < 1e-8

     end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
     end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True)
@@ -783,7 +784,7 @@ def test_restart_parity_with_val(tmp_path):
     trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
     loss_v1 = model.last_loss

-    assert(abs(loss - loss_v1) < 1e-8)
+    assert abs(loss - loss_v1) < 1e-8

     end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
     end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True)
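For reference, a small usage sketch of the compare_state_dicts helper from the earlier hunk, assuming it is in scope; the checkpoint contents are made up for illustration. An empty result means the two checkpoints agree on every leaf.

    import torch

    ckpt_a = {"state_dict": {"layer.weight": torch.zeros(2, 2)}, "epoch": 0, "global_step": 4}
    ckpt_b = {"state_dict": {"layer.weight": torch.zeros(2, 2)}, "epoch": 0, "global_step": 4}

    # Identical checkpoints: no differing leaves are reported.
    assert compare_state_dicts(ckpt_a, ckpt_b) == {}

    # A single differing scalar shows up keyed by its leaf name.
    ckpt_b["global_step"] = 5
    assert compare_state_dicts(ckpt_a, ckpt_b) == {"global_step": "4 != 5"}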