Skip to content

Commit 16f0021

Browse files
committed
Check devices matching during load random seed for device.
1 parent 2869340 commit 16f0021

File tree

4 files changed

+6
-8
lines changed

4 files changed

+6
-8
lines changed

qmb/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
import tyro
1212
from .model_dict import model_dict, ModelProto, NetworkProto
13-
from .random_engine import load_random_engine_state
13+
from .random_engine import dump_random_engine_state, load_random_engine_state
1414

1515

1616
@dataclasses.dataclass
@@ -77,6 +77,7 @@ def save(self, data: typing.Any, step: int) -> None:
7777
"""
7878
Save data to checkpoint.
7979
"""
80+
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.device), "device_type": self.device.type}
8081
data_pth = self.folder() / "data.pth"
8182
local_data_pth = self.folder() / f"data.{step}.pth"
8283
torch.save(data, local_data_pth)
@@ -151,7 +152,10 @@ def main(self, *, model_param: typing.Any = None, network_param: typing.Any = No
151152
elif "random" in data:
152153
logging.info("Loading random seed from the checkpoint")
153154
torch.set_rng_state(data["random"]["host"])
154-
load_random_engine_state(data["random"]["device"], self.device)
155+
if data["random"]["device_type"] == self.device.type:
156+
load_random_engine_state(data["random"]["device"], self.device)
157+
else:
158+
logging.info("Skipping loading random engine state for device since the device type does not match")
155159
else:
156160
logging.info("Random seed not specified, using current seed: %d", torch.seed())
157161

qmb/imag.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .subcommand_dict import subcommand_dict
1717
from .model_dict import ModelProto
1818
from .optimizer import initialize_optimizer, scale_learning_rate
19-
from .random_engine import dump_random_engine_state
2019

2120

2221
@dataclasses.dataclass
@@ -513,7 +512,6 @@ def closure() -> torch.Tensor:
513512
data["imag"]["global"] += 1
514513
data["network"] = network.state_dict()
515514
data["optimizer"] = optimizer.state_dict()
516-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
517515
self.common.save(data, data["imag"]["global"])
518516
logging.info("Checkpoint successfully saved")
519517

qmb/rldiag.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .model_dict import ModelProto
1515
from .optimizer import initialize_optimizer
1616
from .bitspack import pack_int
17-
from .random_engine import dump_random_engine_state
1817

1918

2019
def lanczos_energy(model: ModelProto, configs: torch.Tensor, step: int, threshold: float) -> tuple[float, torch.Tensor]:
@@ -215,7 +214,6 @@ def main(self) -> None:
215214
data["rldiag"]["local"] += 1
216215
data["network"] = network.state_dict()
217216
data["optimizer"] = optimizer.state_dict()
218-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
219217
self.common.save(data, data["rldiag"]["global"])
220218
logging.info("Checkpoint successfully saved")
221219

qmb/vmc.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .common import CommonConfig
1212
from .subcommand_dict import subcommand_dict
1313
from .optimizer import initialize_optimizer
14-
from .random_engine import dump_random_engine_state
1514

1615

1716
@dataclasses.dataclass
@@ -133,7 +132,6 @@ def closure() -> torch.Tensor:
133132
data["vmc"]["global"] += 1
134133
data["network"] = network.state_dict()
135134
data["optimizer"] = optimizer.state_dict()
136-
data["random"] = {"host": torch.get_rng_state(), "device": dump_random_engine_state(self.common.device)}
137135
self.common.save(data, data["vmc"]["global"])
138136
logging.info("Checkpoint successfully saved")
139137

0 commit comments

Comments
 (0)