@@ -38,14 +38,13 @@ def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get(
-            "checkpoint", resource_dir / "demo_rand_params.pth"
-        )
-        params_path = kwargs.get("params", resource_dir / "demo_config.json")
-
+        checkpoint_path = kwargs.get("checkpoint", None)
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
 
+        # Params file.
+        params_path = kwargs.get("params", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -66,6 +65,7 @@ def __init__(self, **kwargs):
         # flake8: noqa: TOR102
         cps = []
         # Load sharded checkpoint.
+        checkpoint = {}
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -93,7 +93,7 @@ def __init__(self, **kwargs):
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
         # Load single checkpoint.
-        else:
+        elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
@@ -122,8 +122,12 @@ def __init__(self, **kwargs):
 """
             )
 
-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
+        # Get optional params.
+        params = {}
+        if params_path:
+            with open(params_path, "r") as f:
+                params = json.loads(f.read())
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -170,7 +174,11 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
-            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+            # Get checkpoint dtype.
+            if checkpoint:
+                self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+            else:
+                self.model_.checkpoint_dtype = None
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -244,16 +252,19 @@ def __init__(self, **kwargs):
             # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
             # by default initialized to fp32. This is fine because every other supported type
             # losslessly converts to fp32, so we don't lose precision here.
-            missing, unexpected = self.model_.load_state_dict(
-                checkpoint,
-                strict=False,
-                assign=True,
-            )  # self.model_ = Transformer(gptconf)
+            if checkpoint:
+                missing, unexpected = self.model_.load_state_dict(
+                    checkpoint,
+                    strict=False,
+                    assign=True,
+                )  # self.model_ = Transformer(gptconf)
+            else:
+                print("Checkpoint not provided, defaulting to uninitialized weights.")
+                self.model_.to_empty(device="cpu")
         except RuntimeError as e:
             print(
-                "Could not load checkpoint into mode, defaulting to random uninitialized weights."
+                f"Could not load checkpoint into model and will default to uninitialized weights due to error: {e}."
             )
-            print(f"Error: {e}")
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
 
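
Taken together, these hunks make both the checkpoint and the params file optional: when neither is supplied, `checkpoint` and `params` stay empty dicts, `load_state_dict` is skipped, and the meta-initialized Transformer is materialized with uninitialized weights via `to_empty(device="cpu")`. Below is a minimal standalone sketch of that pattern, not the ExecuTorch code itself; `nn.Linear` stands in for `Transformer(model_args)` purely for illustration.

import torch
import torch.nn as nn

# Sketch of the "optional checkpoint" pattern adopted by this diff:
# build the module on the meta device, then either load real weights or
# materialize empty (uninitialized) storage when no checkpoint was provided.
checkpoint = None  # e.g. torch.load(path, map_location="cpu", mmap=True) when a path is given

with torch.device("meta"):
    model = nn.Linear(16, 4)  # stand-in for Transformer(model_args)

if checkpoint:
    model.load_state_dict(checkpoint, strict=False, assign=True)
else:
    print("Checkpoint not provided, defaulting to uninitialized weights.")
    model.to_empty(device="cpu")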