@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import json
+import os
 from typing import Any, Dict
 
 import torch
@@ -52,10 +53,14 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
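This hunk makes the bundled checkpoint optional: self.use_checkpoint starts out False and is flipped to True only when the resolved path points at an actual file, so a missing demo_rand_params.pth no longer has to abort construction. A minimal sketch of the gating pattern, assuming a plain kwargs dict (resolve_checkpoint is a hypothetical helper, not part of the diff; only the "checkpoint" kwarg and the demo_rand_params.pth default come from it):

import os
from pathlib import Path

def resolve_checkpoint(kwargs: dict, ckpt_dir: Path):
    # Mirror the hunk's logic: fall back to the bundled demo file,
    # then probe the filesystem to decide whether to load it at all.
    checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
    use_checkpoint = os.path.isfile(checkpoint_path)
    return checkpoint_path, use_checkpoint

If use_checkpoint comes back False, the model keeps its freshly initialized random weights instead of raising on a nonexistent path.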
@@ -74,18 +79,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
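The second hunk folds the meta-to-tune conversion, the decoder-checkpoint conversion, and the dtype probe into the guarded branch, so checkpoint is only ever referenced when a file was actually loaded; previously the load and the dtype probe ran unconditionally, which is exactly what a missing file used to break. The internals of the get_checkpoint_dtype helper aren't shown in this diff; a hypothetical stand-in that infers a dtype from a state dict the same general way might look like:

import torch

def infer_checkpoint_dtype(state_dict: dict) -> torch.dtype:
    # Assumption for illustration only: report the dtype of the first
    # floating-point tensor in the state dict, defaulting to float32.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.dtype
    return torch.float32

infer_checkpoint_dtype({"w": torch.randn(2, 2, dtype=torch.bfloat16)})  # -> torch.bfloat16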
@@ -108,19 +112,20 @@ def __init__(self, **kwargs):
 
         # Quantize. (skip for now)
 
-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")
 
         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
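The last hunk wraps the state-dict load in the same guard. For context: nn.Module.load_state_dict with strict=False returns the mismatches instead of raising on them, and assign=True binds the checkpoint tensors to the module directly rather than copying into preallocated parameters, which pairs with the mmap=True load above. A self-contained illustration of the reporting behavior the verbose branch prints (toy module and keys, not taken from the diff):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ckpt = {"weight": torch.zeros(4, 4), "extra": torch.zeros(1)}  # no "bias" key

# strict=False tolerates the mismatches and reports them instead of raising.
missing, unexpected = model.load_state_dict(ckpt, strict=False, assign=True)
print(missing)     # ['bias']
print(unexpected)  # ['extra']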