@@ -333,6 +333,35 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
 
     return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
 
+# ref: https://github.com/vllm-project/compressed-tensors/blob/52792be02ec09e59f3517104e755a02d0e003fbb/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
+    scale = scale.float()
+    weights_config = quant_config["config_groups"]["group_0"]["weights"]
+    group_size = weights_config["group_size"]
+    num_bits = weights_config["num_bits"]
+    # only tested with https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/config.json
+    # TODO: extend this if other configurations are needed
+    assert group_size == 32
+    assert num_bits == 4
+    assert quant_config["format"] == "pack-quantized"
+
+    pack_factor = group_size // num_bits  # 8 int4 values per int32 word; relies on group_size being asserted to 32 above
+    mask = (1 << num_bits) - 1  # 0xF, selects a single 4-bit value
+    unpacked = torch.zeros(
+        (weight.shape[0], weight.shape[1] * pack_factor),
+        device=weight.device,
+        dtype=torch.int32,
+    )
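+    # unpack low bits first: nibble i (bits num_bits*i and up) of packed column j becomes output column j * pack_factor + i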
+    for i in range(pack_factor):
+        unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
+    # TODO: may need to unpad
+    unpacked = unpacked - (mask + 1) // 2  # shift unsigned storage [0, 15] down to signed int4 [-8, 7]
+    scale = scale.unsqueeze(2)
+    unpacked = unpacked.to(torch.float32)
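+    # view each row as (num_groups, group_size) so the unsqueezed per-group scale broadcasts across every value in its group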
+    unpacked = unpacked.reshape(-1, unpacked.shape[1] // group_size, group_size)
+    dequantized = (unpacked * scale).reshape(-1, unpacked.shape[1] * group_size)
+    return dequantized
+
 if quant_method == "bitnet":
     for name in self.model_tensors.keys():
         if name.endswith(".weight_scale"):
@@ -371,6 +400,24 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                     ".scales",
                 )
             ]
+elif quant_method == "compressed-tensors":
+    weight_block_size = quant_config["config_groups"]["group_0"]["weights"]["group_size"]
+    quant_config["weight_block_size"] = weight_block_size
+    for name in self.model_tensors.keys():
+        if name.endswith("_packed"):
+            base_name = name.removesuffix("_packed")
+            packed = self.model_tensors[base_name + "_packed"]
+            scale = self.model_tensors[base_name + "_scale"]
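+            # model_tensors holds lazy loaders, hence the p() / s() calls below to materialize the tensors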
+            # TODO: use _shape for unpadding if necessary
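+            # default arguments bind this iteration's tensors; a plain closure would capture only the last loop's values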
+            new_tensors[base_name] = lambda p=packed, s=scale: dequant_compressed_tensor(p(), s())
+            tensors_to_remove += [
+                base_name + n
+                for n in (
+                    "_packed",
+                    "_scale",
+                    "_shape",
+                )
+            ]
 else:
     raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
 
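To make the nibble ordering concrete, here is a minimal standalone sketch (not part of the commit): it unpacks one hand-built int32 word with the same shift-and-mask loop `dequant_compressed_tensor` uses, under the `pack-quantized` assumptions asserted above (`num_bits == 4`).

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits   # 8 int4 values per int32 word
mask = (1 << num_bits) - 1     # 0xF

# 0x76543210 stores the nibbles 0..7 low-bits-first
weight = torch.tensor([[0x76543210]], dtype=torch.int32)
unpacked = torch.zeros((1, pack_factor), dtype=torch.int32)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask

print(unpacked)       # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
print(unpacked - 8)   # after the zero-point shift: tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], ...)
```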