@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
    metadata_override: Path | None
@@ -137,25 +135,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
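+        # rewrite the tensor index so pre-quantized checkpoints are dequantized when read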
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -192,67 +175,215 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
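+        # each value is a zero-argument closure that produces its tensor on demand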
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
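+                # default-bind remote_tensor so each closure captures its own tensor, not the loop variable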
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
 
         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
 
             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
 
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
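+                            # as with the remote case, default-bind the slice so each closure keeps its own data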
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
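+                # each uint8 packs four 2-bit values; shift by 0, 2, 4, 6 and mask with 3 to unpack,
+                # then map the stored {0, 1, 2} down to the ternary range {-1, 0, 1}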
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
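+                    # each scale element covers one block of the weight; expand it to the weight's full shape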
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
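+                    # wf lists the bit offset of each packed value within a word;
+                    # qzeros packs pack_factor values along its last dim, qweight along its first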
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
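+                # g_idx maps each input channel to its quant group; the result is transposed to (out_features, in_features)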
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
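+                        # wrap, don't call: dequantization stays deferred until the tensor is actually read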
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -4381,27 +4512,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):