
Commit 5d4d0ca

lucylq authored and facebook-github-bot committed
CI failures with dtype-override (pytorch#1919)
Summary:
Pull Request resolved: pytorch#1919

language_llama failure: https://www.internalfb.com/sandcastle/workflow/3071454945868675212
llama stories failure: https://www.internalfb.com/intern/testinfra/diagnostics/5066549797810715.562950083098286.1707509873/

Reviewed By: larryliu0820, angelayi

Differential Revision: D53625845

fbshipit-source-id: 59b4af4cd5329c7c8baa199831ebf1d21540cb81
1 parent 83d4e52 commit 5d4d0ca

File tree

1 file changed: +6 −6 lines changed


examples/models/llama2/model.py

Lines changed: 6 additions & 6 deletions
@@ -499,6 +499,12 @@ def __init__(self, **kwargs):
         device = "cpu"
         # flake8: noqa: TOR102
         checkpoint = torch.load(checkpoint_path, map_location=device)
+        if kwargs.get("fairseq2", False):
+            print("Using fairseq2 checkpoint")
+            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
+        if "model" in checkpoint:
+            # NB: some checkpoint contains a "model" field, which is the actual weights dict
+            checkpoint = checkpoint["model"]
         # get checkpoint dtype
         self.dtype = None
         if len(checkpoint) > 0:
@@ -513,12 +519,6 @@ def __init__(self, **kwargs):
             print(
                 f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
             )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
-        if "model" in checkpoint:
-            # NB: some checkpoint contains a "model" field, which is the actual weights dict
-            checkpoint = checkpoint["model"]
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         max_seq_len = 128
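
The change moves the fairseq2 conversion and the "model"-key unwrapping so they run immediately after torch.load, ahead of the dtype scan that follows; otherwise the scan inspects wrapper entries instead of the actual weight tensors. Below is a minimal sketch of that ordering, not the actual model.py code: load_and_normalize_checkpoint and detect_checkpoint_dtype are hypothetical helper names, and the fairseq2 conversion call is stubbed out to keep the sketch self-contained.

    import torch

    def load_and_normalize_checkpoint(checkpoint_path: str, fairseq2: bool = False) -> dict:
        """Load a checkpoint and unwrap it before any dtype inspection.

        Hypothetical helper; in model.py this logic lives inline in __init__.
        """
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if fairseq2:
            # model.py calls convert_to_llama_checkpoint(checkpoint=checkpoint) here;
            # omitted in this sketch to keep it self-contained.
            pass
        if "model" in checkpoint:
            # Some checkpoints nest the actual weights under a "model" key.
            checkpoint = checkpoint["model"]
        return checkpoint

    def detect_checkpoint_dtype(checkpoint: dict):
        """Return the dtype of the first tensor found, or None if there is none.

        Mirrors the "get checkpoint dtype" scan that runs right after the moved
        block; it only sees real weights once the wrapper keys are removed.
        """
        for value in checkpoint.values():
            if isinstance(value, torch.Tensor):
                return value.dtype
        return None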
