
Commit 697a2cc

Leave weights uninitialized for checkpoint load fail
1 parent 62ccef1 commit 697a2cc


examples/models/llama/model.py

Lines changed: 15 additions & 8 deletions
@@ -236,14 +236,21 @@ def __init__(self, **kwargs):
                 eviction_batch_size=eviction_batch_size,
             )
 
-        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-        # Because we are using device="meta", tensors do not have memory associated with them
-        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )  # self.model_ = Transformer(gptconf)
+        missing, unexpected = None, None
+        try:
+            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+            # Because we are using device="meta", tensors do not have memory associated with them
+            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )  # self.model_ = Transformer(gptconf)
+        except RuntimeError as e:
+            print(
+                "Could not load checkpoint into model, defaulting to random uninitialized weights."
+            )
+            print(f"Error: {e}")
 
         if missing:
             missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
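
For context, the pattern being wrapped here is meta-device construction followed by load_state_dict(..., assign=True). The sketch below is a minimal standalone illustration, not the ExecuTorch code: TinyModel and the deliberately mis-shaped checkpoint are made up, and it assumes PyTorch >= 2.1 for the assign keyword. It shows why assign=True is needed for meta tensors and how a RuntimeError from an incompatible checkpoint leaves the weights uninitialized, which is the fallback this commit adds.

import torch
import torch.nn as nn


# Hypothetical toy module standing in for the real Transformer.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


# Build the module on the meta device: parameters exist but have no storage,
# so an in-place copy during load_state_dict would be a no-op.
with torch.device("meta"):
    model = TinyModel()

# A deliberately mis-shaped checkpoint to trigger the failure path.
checkpoint = {
    "linear.weight": torch.randn(3, 4),
    "linear.bias": torch.randn(3),
}

missing, unexpected = None, None
try:
    # assign=True swaps the meta tensors for the checkpoint tensors outright
    # instead of copying into the storage-less meta tensors.
    missing, unexpected = model.load_state_dict(
        checkpoint,
        strict=False,
        assign=True,
    )
except RuntimeError as e:
    # Shape mismatches raise RuntimeError even with strict=False; the model is
    # left with its meta (uninitialized) weights, mirroring the commit's fallback.
    print("Could not load checkpoint into model, defaulting to random uninitialized weights.")
    print(f"Error: {e}")

print(model.linear.weight.is_meta)  # True: the weights were never materialized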
