Commit b87b6c9

Remove checkpoint file requirement
1 parent b0d91a8 commit b87b6c9

File tree

1 file changed: +24 -19 lines changed

  • examples/models/llama3_2_vision/text_decoder/model.py
examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 24 additions & 19 deletions
@@ -7,6 +7,7 @@
 # pyre-unsafe

 import json
+import os
 from typing import Any, Dict

 import torch
@@ -52,10 +53,14 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.use_checkpoint = False

         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
@@ -74,18 +79,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())

-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
@@ -108,19 +112,20 @@ def __init__(self, **kwargs):

         # Quantize. (skip for now)

-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")

         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
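
The gist of the change, as a minimal standalone sketch (the helper name maybe_load_checkpoint and its signature are illustrative, not part of this diff): weights are loaded and applied only when the checkpoint path points to an existing file; otherwise the model keeps the parameters it was constructed with.

import os

import torch


def maybe_load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> None:
    # Load weights only when the checkpoint file actually exists; otherwise the
    # model keeps whatever parameters it was initialized with.
    if not os.path.isfile(checkpoint_path):
        return
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
    missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)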
