diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh
index 1f42880b8f9..0727eecf770 100755
--- a/.ci/scripts/test_model.sh
+++ b/.ci/scripts/test_model.sh
@@ -87,8 +87,8 @@ test_model() {
     bash examples/models/llava/install_requirements.sh
     STRICT="--no-strict"
   fi
-  if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" ]]; then
-    # Install requirements for llama vision
+  if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" || "$MODEL_NAME" == "llama3_2_text_decoder" ]]; then
+    # Install requirements for llama vision.
     bash examples/models/llama3_2_vision/install_requirements.sh
   fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py
index 73d79bb08e0..943f9da730e 100644
--- a/examples/models/llama3_2_vision/text_decoder/model.py
+++ b/examples/models/llama3_2_vision/text_decoder/model.py
@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import json
+import os
 from typing import Any, Dict
 
 import torch
@@ -52,10 +53,15 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.dtype = None
+        self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
@@ -74,18 +80,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
@@ -108,19 +113,20 @@ def __init__(self, **kwargs):
 
         # Quantize. (skip for now)
 
-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")
 
         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
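
A minimal, self-contained sketch of the control flow this patch introduces: the model is always constructed, but a checkpoint is only loaded, converted, and dtype-inspected when the file actually exists on disk, so export with random weights keeps working when no checkpoint is supplied. `build_model` and `load_decoder` are hypothetical stand-ins, and the dtype lookup approximates `get_checkpoint_dtype` by peeking at the first tensor.

import os
from typing import Callable, Optional, Tuple

import torch


def load_decoder(
    checkpoint_path: str,
    build_model: Callable[[], torch.nn.Module],
    device: str = "cpu",
) -> Tuple[torch.nn.Module, Optional[torch.dtype]]:
    # Mirrors the new `self.use_checkpoint` flag: only load weights when the
    # given path points at an actual file.
    use_checkpoint = os.path.isfile(checkpoint_path)

    checkpoint = None
    dtype: Optional[torch.dtype] = None
    if use_checkpoint:
        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
        # Stand-in for get_checkpoint_dtype(): infer dtype from the weights.
        # As in the patch, dtype stays None when no checkpoint was loaded.
        dtype = next(iter(checkpoint.values())).dtype

    # The model is always constructed; without a checkpoint it keeps its
    # random initialization.
    model = build_model()

    if use_checkpoint:
        # strict=False tolerates keys dropped during checkpoint conversion;
        # assign=True adopts the mmap'd tensors instead of copying them.
        model.load_state_dict(checkpoint, strict=False, assign=True)
    return model, dtype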