Commit 1ea6e3d

[Kunlunxin] Support DeepSeek R1/V3 w4a8int8 per-channel quantization (#226)
Co-authored-by: lishaohao <lsh862702688@163.com>
1 parent ea8540b, commit 1ea6e3d

15 files changed: +759 / -18 lines changed


angelslim/compressor/quant/core/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from .sample_func import EMASampler, MultiStepSampler  # noqa: F401
 from .save import DeepSeekV3PTQSaveMulti  # noqa: F401
 from .save import DeepSeekV3PTQSaveSingle  # noqa: F401
+from .save import DeepSeekV3W4A8Int8Save  # noqa: F401
 from .save import PTQOnlyScaleSave  # noqa: F401
 from .save import PTQPTMSave  # noqa: F401
 from .save import PTQSaveVllmHF  # noqa: F401

angelslim/compressor/quant/core/config.py

Lines changed: 7 additions & 0 deletions
@@ -103,6 +103,13 @@ def __init__(self, config, global_config=None):
             self.low_memory = config.quantization.low_memory
             self.quant_analyse = config.quantization.quant_analyse
             self.quant_vit = config.quantization.quant_vit
+        elif "w4a8i8" in self.quant_algo:
+            group_size = quantization_args.quant_method["group_size"]
+            self.quant_algo_info = {
+                "group_size": group_size,
+                "ignore_layers": quantization_args.ignore_layers,
+            }
+            self.low_memory = config.quantization.low_memory
         elif "int8" in self.quant_algo:
             is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static"
             assert (

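The new branch only records group_size and ignore_layers, which suggests the 4-bit weight half of the w4a8 scheme is quantized group-wise along the input dimension. Below is a minimal sketch of group-wise symmetric int4 weight quantization under that assumption; the function name, the group_size value of 128, and the routine itself are illustrative and not code from this commit.

# Illustrative group-wise symmetric int4 weight quantization (the "w4" half of
# w4a8), assuming group_size from the config is the per-group granularity.
# This is a sketch, not the implementation added by this commit.
import torch


def quantize_weight_int4_groupwise(weight: torch.Tensor, group_size: int = 128):
    out_features, in_features = weight.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    qmax = 7.0  # symmetric int4 range [-7, 7], leaving the -8 code unused
    grouped = weight.float().reshape(out_features, in_features // group_size, group_size)
    abs_max = grouped.abs().amax(dim=-1, keepdim=True)
    scale = abs_max.clamp(min=1e-8) / qmax          # one scale per (row, group)
    q = torch.clamp(torch.round(grouped / scale), -qmax, qmax)
    return q.reshape(out_features, in_features).to(torch.int8), scale.squeeze(-1)


# Example: a [16, 256] weight with group_size=128 yields scales of shape [16, 2].
w = torch.randn(16, 256, dtype=torch.bfloat16)
q, s = quantize_weight_int4_groupwise(w, group_size=128)
print(q.shape, s.shape)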
angelslim/compressor/quant/core/quant_func.py

Lines changed: 14 additions & 0 deletions
@@ -474,3 +474,17 @@ def reduce_block_padding(input: torch.Tensor, block_sizes: dict, pad_value: floa
         padded_tensor = F.pad(padded_tensor, pad, value=pad_value)

     return padded_tensor
+
+
+class Int8PerChannelQuantizer:
+    """Per-channel symmetric int8 quantizer."""
+
+    @torch.no_grad()
+    def quantize(self, tensor: torch.Tensor):
+        assert tensor.dtype == torch.bfloat16
+        qmax = 127.0
+        abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0]
+        scale = abs_max / qmax
+        quantized = torch.round(tensor / scale)
+        quantized = torch.clamp(quantized, -qmax, qmax)
+        return quantized.to(torch.int8), scale.to(torch.float32)

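A quick usage sketch for the new Int8PerChannelQuantizer follows; the import path mirrors the file touched above, and the dequantization step is my own illustration of how the returned per-channel scale would be applied, not part of this commit.

# Usage sketch for the new quantizer; the import path follows the file location
# in this commit. The dequantization at the end is an illustration of how the
# returned per-channel scale is applied, not code from the commit itself.
import torch

from angelslim.compressor.quant.core.quant_func import Int8PerChannelQuantizer

quantizer = Int8PerChannelQuantizer()
weight = torch.randn(4, 8, dtype=torch.bfloat16)  # [out_channels, in_features]

q, scale = quantizer.quantize(weight)
print(q.dtype, scale.shape)  # torch.int8, torch.Size([4, 1])

# Dequantize: one fp32 scale per output channel (the dim=1 reduction above).
reconstructed = q.to(torch.float32) * scale
print((reconstructed - weight.float()).abs().max())  # small round-trip error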