@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -137,25 +135,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16

+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -192,67 +175,215 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
             return None
         raise KeyError(f"could not find any of: {keys}")

-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
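+                # bind remote_tensor as a default argument so each closure
+                # captures its own tensor (avoids Python's late-binding pitfall)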
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()

         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name

             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}

-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None

                 for name in model_part.keys():
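+                    # store a thunk per tensor; nothing is materialized until get_tensors() calls it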
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen

         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
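+        # re-wrap pre-quantized checkpoint tensors in thunks that dequantize lazily;
+        # supported quant_method values are "bitnet", "fp8" and "gptq"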
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
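+                # each uint8 packs four 2-bit codes; unpack them along a new dim,
+                # then subtract 1 so the codes {0,1,2} map to the ternary {-1,0,1}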
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
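+                # block-wise scales (e.g. one scale per 128x128 block): expand each
+                # scale across its block, then crop any padding past the weight shape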
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
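+                    # wf holds the bit offset of each value packed in a word
+                    # (e.g. 0, 4, 8, ..., 28 for 4-bit values in 32-bit words)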
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
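+                # g_idx maps each input row to its quantization group;
+                # the usual gptq reconstruction is scales * (weight - zeros)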
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
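+                        # bind w and s as defaults so each closure keeps its own pair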
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()

     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -4381,27 +4512,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)

-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-

 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):