
Commit 1bd57a3

convert: add dequant function for compressed_tensor (kimi-k2-thinking)
1 parent 7f09a68 commit 1bd57a3

File tree: 1 file changed (+47 −0 lines)

convert_hf_to_gguf.py

Lines changed: 47 additions & 0 deletions
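For reference, the new conversion path reads the following keys from the model's quantization_config. This is a hedged sketch assembled only from the keys and asserted values visible in the diff below, not a verbatim copy of the Kimi-K2-Thinking config.json; quant_method is assumed to be picked up from the same dict by the surrounding dequant_model code, which is not part of this change.

quant_config = {
    "quant_method": "compressed-tensors",  # dispatches to the new elif branch (assumed read by surrounding code)
    "format": "pack-quantized",            # asserted by dequant_compressed_tensor
    "config_groups": {
        "group_0": {
            "weights": {
                "num_bits": 4,             # asserted
                "group_size": 32,          # asserted; also copied into weight_block_size
            },
        },
    },
}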
@@ -333,6 +333,35 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)

                 return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T

+            # ref: https://github.com/vllm-project/compressed-tensors/blob/52792be02ec09e59f3517104e755a02d0e003fbb/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+            def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+                weights_config = quant_config["config_groups"]["group_0"]["weights"]
+                group_size = weights_config["group_size"]
+                num_bits = weights_config["num_bits"]
+                # only tested with https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/config.json
+                # TODO: extend this if other configurations are needed
+                assert(group_size == 32)
+                assert(num_bits == 4)
+                assert(quant_config["format"] == "pack-quantized")
+
+                pack_factor = group_size // num_bits
+                mask = (1 << num_bits) - 1
+                unpacked = torch.zeros(
+                    (weight.shape[0], weight.shape[1] * pack_factor),
+                    device=weight.device,
+                    dtype=torch.int32,
+                )
+                for i in range(pack_factor):
+                    unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
+                # TODO: may need to unpad
+                unpacked = unpacked - (mask + 1) // 2  # convert uint4 to int4 (shift scale)
+                scale = scale.unsqueeze(2)
+                unpacked = unpacked.to(torch.float32)
+                unpacked = unpacked.reshape(-1, unpacked.shape[1] // group_size, group_size)
+                dequantized = (unpacked * scale).reshape(-1, unpacked.shape[1] * group_size)
+                return dequantized
+
             if quant_method == "bitnet":
                 for name in self.model_tensors.keys():
                     if name.endswith(".weight_scale"):
@@ -371,6 +400,24 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                                 ".scales",
                             )
                         ]
+            elif quant_method == "compressed-tensors":
+                weight_block_size = quant_config["config_groups"]["group_0"]["weights"]["group_size"]
+                quant_config["weight_block_size"] = weight_block_size
+                for name in self.model_tensors.keys():
+                    if name.endswith("_packed"):
+                        base_name = name.removesuffix("_packed")
+                        packed = self.model_tensors[base_name + "_packed"]
+                        scale = self.model_tensors[base_name + "_scale"]
+                        # TODO: use _shape for unpadding if necessary
+                        new_tensors[base_name] = lambda p=packed, s=scale: dequant_compressed_tensor(p(), s())
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                "_packed",
+                                "_scale",
+                                "_shape",
+                            )
+                        ]
             else:
                 raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
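One detail in the branch above: the lambda captures packed and scale through default arguments (p=packed, s=scale) rather than closing over the loop variables. Python closures bind names late, so a plain lambda would see only the tensors from the last loop iteration when the deferred conversion is eventually evaluated; default arguments freeze the current values per tensor. A small sketch of the difference, with illustrative names that are not from the commit:

fns_late = []
fns_bound = []
for x in ("a_packed", "b_packed"):
    fns_late.append(lambda: x)         # late binding: every call sees the last x
    fns_bound.append(lambda x=x: x)    # default argument: each keeps its own x

print([f() for f in fns_late])     # ['b_packed', 'b_packed']
print([f() for f in fns_bound])    # ['a_packed', 'b_packed']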
