Skip to content

Commit 6b1feba

Browse files
author
igor
committed
Small update
1 parent 49a6d9e commit 6b1feba

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

batchflow/models/torch/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,9 +1817,8 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
18171817
from safetensors.torch import load_file
18181818
state_dict = load_file(file, device=device)
18191819

1820-
with torch.no_grad():
1821-
self.model = Network(inputs=None, config=self.config, device=self.device)
1822-
self.model.load_state_dict(state_dict)
1820+
self.initialize()
1821+
self.model.load_state_dict(state_dict)
18231822

18241823
self.model_to_device()
18251824

0 commit comments

Comments
 (0)