@@ -38,14 +38,13 @@ def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get(
-            "checkpoint", resource_dir / "demo_rand_params.pth"
-        )
-        params_path = kwargs.get("params", resource_dir / "demo_config.json")
-
+        checkpoint_path = kwargs.get("checkpoint", None)
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
 
+        # Params file.
+        params_path = kwargs.get("params", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
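In effect, this hunk turns `checkpoint` and `params` into optional kwargs instead of pointing them at bundled demo files. A hypothetical caller might then look like the sketch below; the class name `Llama2Model` and the file paths are illustrative assumptions, not taken from the diff:

```python
# Hypothetical usage sketch; class name and paths are placeholders.
model_with_weights = Llama2Model(
    checkpoint="/tmp/consolidated.00.pth",  # assumed single-file checkpoint
    params="/tmp/params.json",              # assumed model-args JSON
    use_kv_cache=True,
)

# With no checkpoint/params at all, the wrapper now falls back to defaults
# and (per the later hunks) random uninitialized weights.
model_random = Llama2Model(use_kv_cache=True)
```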
@@ -66,6 +65,7 @@ def __init__(self, **kwargs):
         # flake8: noqa: TOR102
         cps = []
         # Load sharded checkpoint.
+        checkpoint = {}
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -93,7 +93,7 @@ def __init__(self, **kwargs):
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
         # Load single checkpoint.
-        else:
+        elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
@@ -123,10 +123,17 @@ def __init__(self, **kwargs):
             )
 
         # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
+        if checkpoint:
+            self.dtype = get_checkpoint_dtype(checkpoint)
+        else:
+            self.dtype = None
+
+        # Get optional params.
+        params = {}
+        if params_path:
+            with open(params_path, "r") as f:
+                params = json.loads(f.read())
 
-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
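The new guards follow one pattern: only read the checkpoint or params file when a path was actually supplied, and otherwise fall back to an empty dict, presumably so downstream model-args construction can rely on its defaults. A minimal standalone sketch of that pattern (the helper name and file name are placeholders, not part of the diff):

```python
import json
from typing import Any, Dict, Optional

def load_optional_params(params_path: Optional[str]) -> Dict[str, Any]:
    """Return the parsed params JSON, or an empty dict when no path is given."""
    if not params_path:
        return {}
    with open(params_path, "r") as f:
        return json.load(f)

params = load_optional_params(None)             # -> {} (use defaults)
# params = load_optional_params("params.json")  # placeholder path
```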
@@ -241,16 +248,21 @@ def __init__(self, **kwargs):
             # assign=True: load params/buffers by assignment instead of performing an in-place copy.
             # Because we are using device="meta", tensors do not have memory associated with them
             # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-            missing, unexpected = self.model_.load_state_dict(
-                checkpoint,
-                strict=False,
-                assign=True,
-            )  # self.model_ = Transformer(gptconf)
+            if checkpoint:
+                missing, unexpected = self.model_.load_state_dict(
+                    checkpoint,
+                    strict=False,
+                    assign=True,
+                )  # self.model_ = Transformer(gptconf)
+            else:
+                print(
+                    "Checkpoint not provided, defaulting to random uninitialized weights."
+                )
+                self.model_.to_empty(device="cpu")
         except RuntimeError as e:
             print(
-                "Could not load checkpoint into mode, defaulting to random uninitialized weights."
+                f"Could not load checkpoint into model and will default to random uninitialized weights due to error: {e}."
             )
-            print(f"Error: {e}")
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
 
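For reference, the meta-device pattern this hunk builds on can be reduced to the sketch below; the `nn.Linear` module stands in for `Transformer(gptconf)` and the path is a placeholder, so this illustrates the technique rather than the wrapper's actual code. Parameters are first created on the meta device with no storage, then either swapped in from a checkpoint via `load_state_dict(..., assign=True)` or materialized as uninitialized CPU memory via `to_empty` when no checkpoint exists.

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters exist but have no storage.
with torch.device("meta"):
    model = nn.Linear(16, 16)  # stand-in for Transformer(gptconf)

checkpoint_path = None  # or a real file, e.g. "consolidated.00.pth" (placeholder)

if checkpoint_path:
    state = torch.load(checkpoint_path, map_location="cpu", mmap=True)
    # assign=True replaces the meta tensors with the loaded ones instead of
    # copying into them (a copy would be a no-op for storage-less meta tensors).
    missing, unexpected = model.load_state_dict(state, strict=False, assign=True)
else:
    # No checkpoint: give every meta tensor real (uninitialized) CPU storage.
    model.to_empty(device="cpu")
```

Either branch leaves the module with concrete CPU tensors, which matches the comment in the diff about quantization needing real values for meta-initialized tensors.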