4 changes: 2 additions & 2 deletions .ci/scripts/test_model.sh
@@ -87,8 +87,8 @@ test_model() {
     bash examples/models/llava/install_requirements.sh
     STRICT="--no-strict"
   fi
-  if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" ]]; then
-    # Install requirements for llama vision
+  if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" || "$MODEL_NAME" == "llama3_2_text_decoder" ]]; then
+    # Install requirements for llama vision.
     bash examples/models/llama3_2_vision/install_requirements.sh
   fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
44 changes: 25 additions & 19 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import json
+import os
 from typing import Any, Dict
 
 import torch
@@ -52,10 +53,15 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.dtype = None
+        self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
@@ -74,18 +80,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
@@ -108,19 +113,20 @@ def __init__(self, **kwargs):
 
         # Quantize. (skip for now)
 
-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")
 
         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
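
The heart of the model.py change is making checkpoint loading optional: the decoder records whether the checkpoint file actually exists and only converts and loads it in that case, so the module can still be constructed (and exported in CI) from its random initialization. Below is a minimal, standalone sketch of that gating pattern; `DemoDecoder`, `convert_checkpoint`, and `build_model` are hypothetical stand-ins rather than the ExecuTorch classes, and only the `os.path.isfile` gate and the `load_state_dict(strict=False, assign=True)` call mirror the diff above.

```python
# Standalone sketch of the optional-checkpoint pattern from the diff.
# DemoDecoder / convert_checkpoint / build_model are illustrative only.
import os
from typing import Any, Dict

import torch
import torch.nn as nn


class DemoDecoder(nn.Module):
    """Tiny stand-in for the Llama 3.2 text decoder."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def convert_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    # Placeholder for the key remapping done in the diff
    # (llama3_vision_meta_to_tune, to_decoder_checkpoint).
    return ckpt


def build_model(checkpoint_path: str, verbose: bool = False) -> DemoDecoder:
    # Mirrors the new `self.use_checkpoint = os.path.isfile(checkpoint_path)` logic.
    use_checkpoint = os.path.isfile(checkpoint_path)

    model = DemoDecoder()
    if use_checkpoint:
        checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True)
        checkpoint = convert_checkpoint(checkpoint)
        # strict=False tolerates missing/unexpected keys; assign=True assigns
        # the loaded tensors directly instead of copying into existing params.
        missing, unexpected = model.load_state_dict(
            checkpoint, strict=False, assign=True
        )
        if verbose:
            print("missing keys:", missing)
            print("unexpected keys:", unexpected)
    return model


if __name__ == "__main__":
    # With no file on disk, the isfile() gate keeps the random initialization,
    # which is what lets CI export a demo model without real weights.
    model = build_model("demo_rand_params.pth")
    print(model(torch.randn(1, 16)).shape)
```

In the actual module the same flag also guards dtype detection, so `self.dtype` simply stays `None` (its new default) when no checkpoint is present.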