@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import json
+import os
 from typing import Any, Dict
 
 import torch
@@ -52,10 +53,15 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.dtype = None
+        self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
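The added `use_checkpoint` flag makes weight loading opt-in: `dtype` stays `None` and the model keeps its random initialization unless the resolved `checkpoint` path points at a real file. A minimal sketch of that detection pattern, with the class name and `resources` directory as hypothetical stand-ins (only the `os.path.isfile` check on the kwargs-resolved path comes from the diff):

import os
from pathlib import Path

class OptionalCheckpointSketch:
    def __init__(self, **kwargs):
        self.dtype = None            # resolved later, only if a checkpoint is found
        self.use_checkpoint = False  # assume random init until a file is located
        ckpt_dir = Path("resources")  # stand-in for get_default_model_resource_dir(__file__)
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        if os.path.isfile(checkpoint_path):  # only trust paths that exist on disk
            self.use_checkpoint = True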
@@ -74,18 +80,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-            checkpoint = llama3_vision_meta_to_tune(checkpoint)
-            checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
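With this reordering, the Meta-to-torchtune conversion and the dtype probe run only when a single-file checkpoint was actually found, and `self.dtype` is set from the converted weights instead of in a separate later step. A rough, self-contained sketch of that flow; `load_single_checkpoint` is illustrative, the `convert_fns` stand in for `llama3_vision_meta_to_tune` / `to_decoder_checkpoint`, and the inline probe only mimics what `get_checkpoint_dtype` reports:

import torch

def load_single_checkpoint(checkpoint_path, convert_fns, device="cpu"):
    # mmap the file so large checkpoints are not read eagerly into memory
    checkpoint = torch.load(
        checkpoint_path, map_location=device, weights_only=False, mmap=True
    )
    for fn in convert_fns:  # key-renaming / filtering passes, applied in order
        checkpoint = fn(checkpoint)
    # probe the dtype from the converted weights (mimics get_checkpoint_dtype)
    dtype = next(iter(checkpoint.values())).dtype if checkpoint else None
    return checkpoint, dtype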
@@ -108,19 +113,20 @@ def __init__(self, **kwargs):
 
         # Quantize. (skip for now)
 
-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")
 
         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
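The guard is load-bearing: when no file was found (and no sharded directory was given), `checkpoint` is never bound, so skipping the whole block both avoids a NameError and leaves the random initialization in place. For reference, `strict=False` turns key mismatches into the returned `missing`/`unexpected` lists instead of raising, and `assign=True` adopts the checkpoint tensors (keeping their dtype and storage) rather than copying into the freshly initialized parameters. A toy illustration, unrelated to the project's classes:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # parameters initialize as float32
state = {"weight": torch.zeros(4, 4, dtype=torch.float16), "extra": torch.ones(1)}  # no "bias" key

missing, unexpected = model.load_state_dict(state, strict=False, assign=True)
print(missing)           # ['bias'] -- left at its random initialization
print(unexpected)        # ['extra'] -- reported instead of raising
print(model.weight.dtype)  # torch.float16: assign=True keeps the state-dict tensor's properties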