Skip to content

Commit e782ca5

Browse files
committed
Check devices matching during load random seed for device.
When dumping random state for one device but loading it from another, the format of random state differs so program will raise error, we need to check whether the device type is unchanged before loading it. PR: USTC-KnowledgeComputingLab/qmb#48 Signed-off-by: Hao Zhang <[email protected]>
2 parents 2869340 + 16f0021 commit e782ca5

File tree

4 files changed

+6
-8
lines changed

4 files changed

+6
-8
lines changed

qmb/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
import tyro
1212
from .model_dict import model_dict, ModelProto, NetworkProto
13-
from .random_engine import load_random_engine_state
13+
from .random_engine import dump_random_engine_state, load_random_engine_state
1414

1515

1616
@dataclasses.dataclass
@@ -77,6 +77,7 @@ def save(self, data: typing.Any, step: int) -> None:
7777
"""
7878
Save data to checkpoint.
7979
"""
80+
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.device), "device_type": self.device.type}
8081
data_pth = self.folder() / "data.pth"
8182
local_data_pth = self.folder() / f"data.{step}.pth"
8283
torch.save(data, local_data_pth)
@@ -151,7 +152,10 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
151152
elif "random" in data:
152153
logging.info("Loading random seed from the checkpoint")
153154
torch.set_rng_state(data["random"]["host"])
154-
load_random_engine_state(data["random"]["device"], self.device)
155+
if data["random"]["device_type"] == self.device.type:
156+
load_random_engine_state(data["random"]["device"], self.device)
157+
else:
158+
logging.info("Skipping loading random engine state for device since the device type does not match")
155159
else:
156160
logging.info("Random seed not specified, using current seed: %d", torch.seed())
157161

qmb/imag.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .subcommand_dict import subcommand_dict
1717
from .model_dict import ModelProto
1818
from .optimizer import initialize_optimizer, scale_learning_rate
19-
from .random_engine import dump_random_engine_state
2019

2120

2221
@dataclasses.dataclass
@@ -513,7 +512,6 @@ def closure() -> torch.Tensor:
513512
data["imag"]["global"] += 1
514513
data["network"] = network.state_dict()
515514
data["optimizer"] = optimizer.state_dict()
516-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
517515
self.common.save(data, data["imag"]["global"])
518516
logging.info("Checkpoint successfully saved")
519517

qmb/rldiag.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .model_dict import ModelProto
1515
from .optimizer import initialize_optimizer
1616
from .bitspack import pack_int
17-
from .random_engine import dump_random_engine_state
1817

1918

2019
def lanczos_energy(model: ModelProto, configs: torch.Tensor, step: int, threshold: float) -> tuple[float, torch.Tensor]:
@@ -215,7 +214,6 @@ def main(self) -> None:
215214
data["rldiag"]["local"] += 1
216215
data["network"] = network.state_dict()
217216
data["optimizer"] = optimizer.state_dict()
218-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
219217
self.common.save(data, data["rldiag"]["global"])
220218
logging.info("Checkpoint successfully saved")
221219

qmb/vmc.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .common import CommonConfig
1212
from .subcommand_dict import subcommand_dict
1313
from .optimizer import initialize_optimizer
14-
from .random_engine import dump_random_engine_state
1514

1615

1716
@dataclasses.dataclass
@@ -133,7 +132,6 @@ def closure() -> torch.Tensor:
133132
data["vmc"]["global"] += 1
134133
data["network"] = network.state_dict()
135134
data["optimizer"] = optimizer.state_dict()
136-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
137135
self.common.save(data, data["vmc"]["global"])
138136
logging.info("Checkpoint successfully saved")
139137

0 commit comments

Comments
 (0)