@@ -196,26 +196,25 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
             with ctx as model_part:
                 assert model_part is not None
                 for name in model_part.keys():
                     if is_safetensors:
+                        data: gguf.utility.LocalTensor = model_part[name]
                         if self.lazy:
-                            data = model_part.get_slice(name)
-                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
+                            data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                         else:
-                            data = model_part.get_tensor(name)
-                            data_gen = lambda data=data: data  # noqa: E731
+                            dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                            data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                     else:
-                        data = model_part[name]
+                        data_torch: Tensor = model_part[name]
                         if self.lazy:
-                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                            data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                         else:
-                            data_gen = lambda data=data: data  # noqa: E731
+                            data_gen = lambda data=data_torch: data  # noqa: E731
                     tensors[name] = data_gen
         # verify tensor name presence and identify potentially missing files
         if len(tensor_names_from_index) > 0:
@@ -249,14 +248,15 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
             # The scale is inverted
             return data / scale.float()
 
-        def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+        def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
             scale = scale.float()
-            if (weight_block_size := quant_config.get("weight_block_size")):
-                # TODO: make sure it's a list of integers
-                for i, size in enumerate(weight_block_size):
+
+            if block_size is not None:
+                for i, size in enumerate(block_size):
                     scale = scale.repeat_interleave(size, i)
-            # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-            scale = scale[tuple(slice(0, size) for size in weight.shape)]
+                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
             return weight.float() * scale
 
         # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
@@ -294,6 +294,41 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
             if quant_config.get("checkpoint_format", "gptq") == "gptq":
                 zeros += 1
             return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+        def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+            assert w.dtype == torch.int32
+            shape = tuple(shape_tensor.tolist())
+            assert len(shape) == 2
+            mask = (1 << num_bits) - 1
+
+            shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+            if self.lazy:
+                shifts = LazyTorchTensor.from_eager(shifts)
+
+            if zero_point is None:
+                offset = 1 << (num_bits - 1)
+            else:
+                assert len(zero_point.shape) == 2
+                offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                offset = offset.reshape(-1, zero_point.shape[1])
+                # trim padding, and prepare for broadcast
+                # NOTE: the zero-point is packed along dim 0
+                offset = offset[:shape[0], :].unsqueeze(-1)
+
+            # extract values
+            # NOTE: the weights are packed along dim 1
+            unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+            unpacked = unpacked.reshape(shape[0], -1)
+
+            # trim padding
+            unpacked = unpacked[:, :shape[1]]
+
+            # prepare for broadcast of the scale
+            unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+            unpacked = unpacked - offset
+
+            return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
         if quant_method == "bitnet":
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale"):
@@ -303,12 +338,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                     self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                     tensors_to_remove.append(name)
         elif quant_method == "fp8":
+            block_size = quant_config.get("weight_block_size")
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale_inv"):
                     weight_name = name.removesuffix("_scale_inv")
                     w = self.model_tensors[weight_name]
                     s = self.model_tensors[name]
-                    self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
         elif quant_method == "gptq":
             for name in self.model_tensors.keys():
@@ -332,11 +368,56 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
332368 ".scales" ,
333369 )
334370 ]
371+ elif quant_method == "compressed-tensors" :
372+ quant_format = quant_config ["format" ]
373+ groups = quant_config ["config_groups" ]
374+ if len (groups ) > 1 :
375+ raise NotImplementedError ("Can't handle multiple config groups for compressed-tensors yet" )
376+ weight_config = tuple (groups .values ())[0 ]["weights" ]
377+
378+ if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized" :
379+ block_size = weight_config .get ("block_structure" , None )
380+ strategy = weight_config .get ("strategy" )
381+ assert strategy == "channel" or strategy == "block"
382+ assert weight_config .get ("group_size" ) is None # didn't find a model using this yet
383+ for name in self .model_tensors .keys ():
384+ if name .endswith (".weight_scale" ):
385+ weight_name = name .removesuffix ("_scale" )
386+ w = self .model_tensors [weight_name ]
387+ s = self .model_tensors [name ]
388+ self .model_tensors [weight_name ] = lambda w = w , s = s : dequant_simple (w (), s (), block_size )
389+ tensors_to_remove .append (name )
390+ elif quant_format == "pack-quantized" :
391+ assert weight_config .get ("strategy" ) == "group"
392+ assert weight_config .get ("type" , "int" ) == "int"
393+ num_bits = weight_config .get ("num_bits" )
394+ group_size = weight_config .get ("group_size" )
395+ assert isinstance (num_bits , int )
396+ assert isinstance (group_size , int )
397+ for name in self .model_tensors .keys ():
398+ if name .endswith (".weight_packed" ):
399+ base_name = name .removesuffix ("_packed" )
400+ w = self .model_tensors [name ]
401+ scale = self .model_tensors [base_name + "_scale" ]
402+ shape = self .model_tensors [base_name + "_shape" ]
403+ zero_point = self .model_tensors .get (base_name + "_zero_point" , lambda : None )
404+ new_tensors [base_name ] = (
405+ lambda w = w , scale = scale , shape = shape , zero_point = zero_point : dequant_packed (
406+ w (), scale (), shape (), zero_point (), num_bits , group_size ,
407+ )
408+ )
409+ tensors_to_remove += [base_name + n for n in ("_packed" , "_shape" , "_scale" )]
410+ if (base_name + "_zero_point" ) in self .model_tensors :
411+ tensors_to_remove .append (base_name + "_zero_point" )
412+ else :
413+ raise NotImplementedError (f"Quant format { quant_format !r} for method { quant_method !r} is not yet supported" )
335414 else :
336415 raise NotImplementedError (f"Quant method is not yet supported: { quant_method !r} " )
416+
337417 for name in tensors_to_remove :
338418 if name in self .model_tensors :
339419 del self .model_tensors [name ]
420+
340421 for name , value in new_tensors .items ():
341422 self .model_tensors [name ] = value
342423
@@ -940,6 +1021,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
+            # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
+            res = "afmoe"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
             # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
             res = "bailingmoe2"
@@ -990,10 +1074,11 @@ def _set_vocab_qwen(self):
         vocab_size = hparams["vocab_size"]
         assert max(tokenizer.get_vocab().values()) < vocab_size
         tokpre = self.get_vocab_base_pre(tokenizer)
-        QwenModel = _get_qwen_model()
+
         merges = []
         vocab = {}
         mergeable_ranks = tokenizer.mergeable_ranks
+        QwenModel = _get_qwen_model()
         for token, rank in mergeable_ranks.items():
             vocab[QwenModel.token_bytes_to_string(token)] = rank
             if len(token) == 1: