Skip to content

Commit 49a6d9e

Browse files
author
igor
committed
Update load model for safetensors
1 parent 515dbc2 commit 49a6d9e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

batchflow/models/torch/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1816,7 +1816,10 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
18161816
if file.endswith(".safetensors"):
18171817
from safetensors.torch import load_file
18181818
state_dict = load_file(file, device=device)
1819-
self.model.load_state_dict(state_dict)
1819+
1820+
with torch.no_grad():
1821+
self.model = Network(inputs=None, config=self.config, device=self.device)
1822+
self.model.load_state_dict(state_dict)
18201823

18211824
self.model_to_device()
18221825

0 commit comments

Comments
 (0)