diff --git a/configs/quantization/deepseekv3/awq_w_only_dsv3.yml b/configs/quantization/deepseekv3/awq_w_only_dsv3.yml
index d0d3976bc..12ba228bd 100755
--- a/configs/quantization/deepseekv3/awq_w_only_dsv3.yml
+++ b/configs/quantization/deepseekv3/awq_w_only_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False
diff --git a/configs/quantization/deepseekv3/osplus_w_a_dsv3.yml b/configs/quantization/deepseekv3/osplus_w_a_dsv3.yml
index 6437fdb63..cea88a552 100755
--- a/configs/quantization/deepseekv3/osplus_w_a_dsv3.yml
+++ b/configs/quantization/deepseekv3/osplus_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False
diff --git a/configs/quantization/deepseekv3/quarot_w_a_dsv3.yml b/configs/quantization/deepseekv3/quarot_w_a_dsv3.yml
index d444d5ef3..3710a85eb 100755
--- a/configs/quantization/deepseekv3/quarot_w_a_dsv3.yml
+++ b/configs/quantization/deepseekv3/quarot_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: Quarot
     weight:
diff --git a/configs/quantization/deepseekv3/rtn_w_a_dsv3.yml b/configs/quantization/deepseekv3/rtn_w_a_dsv3.yml
index e81e2dd4c..f8970c2a6 100755
--- a/configs/quantization/deepseekv3/rtn_w_a_dsv3.yml
+++ b/configs/quantization/deepseekv3/rtn_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: RTN
     weight:
diff --git a/configs/quantization/deepseekv3/rtn_w_only_dsv3.yml b/configs/quantization/deepseekv3/rtn_w_only_dsv3.yml
index 08f96dae8..417edb4e2 100755
--- a/configs/quantization/deepseekv3/rtn_w_only_dsv3.yml
+++ b/configs/quantization/deepseekv3/rtn_w_only_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: RTN
     weight:
diff --git a/configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml b/configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml
index 9c7e53685..a25a9dcc0 100755
--- a/configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml
+++ b/configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False
diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py
index 4251819ea..1c0e6e455 100755
--- a/llmc/compression/quantization/module_utils.py
+++ b/llmc/compression/quantization/module_utils.py
@@ -8,7 +8,6 @@
 from loguru import logger
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 
-from .quant import FloatQuantizer
 from .utils import is_fp8_supported_gpu
 
 if is_fp8_supported_gpu():
@@ -23,6 +22,10 @@
         'Using LLMC Quantizer implementation instead.'
     )
 
+try:
+    from vllm import _custom_ops as ops
+except ModuleNotFoundError:
+    ops = None
 
 try:
     import fast_hadamard_transform
@@ -34,8 +37,6 @@
         'If you need it, please install it firstly.'
     )
 
-from .utils import calculate_zeros_width
-
 
 def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias):
     x, scale = act_quant(x, block_size)
@@ -127,13 +128,105 @@ def new(cls, module):
         return new_module
 
 
+class VllmQuantLinearInt8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.int8))
+        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
+
+        if bias:
+            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
+        else:
+            self.register_buffer('bias', None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale, _ \
+            = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(
+            output_tensor,
+            input_tensor_quant,
+            self.weight.t(),
+            input_tensor_scale,
+            self.weight_scale.float(),
+            self.bias,
+        )
+        return output_tensor.unsqueeze(0)
+
+    @classmethod
+    @torch.no_grad()
+    def new(cls, module):
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        new_module = cls(in_features, out_features, bias)
+        return new_module
+
+
+class VllmQuantLinearFp8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))  # noqa
+        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
+        else:
+            self.register_buffer('bias', None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale \
+            = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(
+            output_tensor,
+            input_tensor_quant,
+            self.weight.t(),
+            input_tensor_scale,
+            self.weight_scale.float(),
+            self.bias,
+        )
+
+        return output_tensor.unsqueeze(0)
+
+    @classmethod
+    @torch.no_grad()
+    def new(cls, module):
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        new_module = cls(in_features, out_features, bias)
+        return new_module
+
+
 class LlmcFp8Linear(nn.Module):
     def __init__(self, in_features, out_features, bias, block_size):
         super().__init__()
         self.block_size = block_size
         self.in_features = in_features
         self.out_features = out_features
-        if bias is not None:
+        if bias:
             self.bias = nn.Parameter(torch.empty(out_features))
         else:
             self.register_parameter('bias', None)
@@ -172,7 +265,7 @@ def forward(self, x):
     def new(cls, module, block_size):
         in_features = module.in_features
         out_features = module.out_features
-        bias = module.bias
+        bias = module.bias is not None
         new_module = cls(in_features, out_features, bias, block_size)
         return new_module
 
diff --git a/llmc/compression/quantization/utils.py b/llmc/compression/quantization/utils.py
index 5ec9195a1..588bf3e99 100755
--- a/llmc/compression/quantization/utils.py
+++ b/llmc/compression/quantization/utils.py
@@ -5,21 +5,6 @@ def make_divisible(c, divisor):
     return (c + divisor - 1) // divisor
 
 
-def calculate_zeros_width(in_features, group_size=128, pack_num=8):
-    if group_size >= 128:
-        size_multiplier = 1
-    elif group_size == 64:
-        size_multiplier = 2
-    elif group_size == 32:
-        size_multiplier = 4
-    else:
-        raise NotImplementedError
-
-    base_width = make_divisible(in_features // group_size, pack_num)
-    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
-    return base_width
-
-
 def is_fp8_supported_gpu():
     if not torch.cuda.is_available():
         return False
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index ebd6aa28d..6f6a563b2 100755
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -4,7 +4,6 @@
 import os
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
-from functools import partial
 
 import torch
 import torch.nn as nn
@@ -16,7 +15,8 @@
 
 from llmc.compression.quantization.module_utils import (
     _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
-    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)
+    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
+    VllmQuantLinearInt8)
 
 
 class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,10 @@ def __init__(self, config, device_map=None, use_cache=False):
         self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
         self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
         torch_dtype = self.config.model.torch_dtype
-        self.torch_dtype = torch_dtype if torch_dtype == 'auto' else eval(torch_dtype)
+        self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
+        self.block_wise_quant = self.config.model.get('block_wise_quant', False)
+        if self.block_wise_quant:
+            assert self.torch_dtype == torch.float8_e4m3fn
         self.device_map = device_map
         self.use_cache = use_cache
         self.mm_model = None
@@ -199,20 +202,32 @@ def build_model(self):
         if hasattr(self.model_config, 'use_cache'):
             self.model_config.use_cache = False
         logger.info(f'self.model_config : {self.model_config}')
-        if self.torch_dtype == torch.float8_e4m3fn:
+        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(config=self.model_config,
                                                               torch_dtype=torch.float16,
                                                               trust_remote_code=True)
             self.find_blocks()
-            self.fp8_block_size \
-                = self.model_config.quantization_config['weight_block_size'][0]
+            if self.torch_dtype == torch.float8_e4m3fn:
+                if self.block_wise_quant:
+                    self.fp8_block_size \
+                        = self.model_config.quantization_config['weight_block_size'][0]
+                    params_dict = {'block_size': self.fp8_block_size}
+                    quant_linear_cls = LlmcFp8Linear
+                else:
+                    params_dict = {}
+                    quant_linear_cls = VllmQuantLinearFp8
+            elif self.torch_dtype == torch.int8:
+                params_dict = {}
+                quant_linear_cls = VllmQuantLinearInt8
+
             for block_idx, block in enumerate(self.blocks):
-                self.replace_module_block(LlmcFp8Linear,
+                self.replace_module_block(quant_linear_cls,
                                           block,
                                           block_idx,
-                                          {'block_size': self.fp8_block_size})
-            self.load_fp8_weight()
+                                          params_dict)
+
+            self.load_quant_weight()
 
             logger.info(f'fp8 block size: {self.fp8_block_size}')
         else:
@@ -226,7 +241,7 @@ def build_model(self):
             )
         logger.info(f'self.model : {self.model}')
 
-    def load_fp8_weight(self):
+    def load_quant_weight(self):
         state_dict = self.model.state_dict()
         model_index_file = os.path.join(self.model_path,
                                         'model.safetensors.index.json')
@@ -241,7 +256,7 @@ def load_fp8_weight(self):
 
         for shard_path, tensor_names in shard_to_tensors.items():
             full_shard_path = os.path.join(self.model_path, shard_path)
-            logger.info(f'Loading FP8 shard: {full_shard_path}')
+            logger.info(f'Loading Quant shard: {full_shard_path}')
             with safe_open(full_shard_path, framework='pt', device='cpu') as f:
                 for weight_name in tensor_names:
                     tensor = f.get_tensor(weight_name)
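
Illustrative sketch (not part of the patch above): a minimal example of how the new VllmQuantLinearInt8 / VllmQuantLinearFp8 wrappers are intended to stand in for nn.Linear modules before the quantized shards are filled in by load_quant_weight(). The toy module and the dtype choice are hypothetical; the forward pass additionally requires vLLM's _custom_ops kernels on a CUDA device.

# Sketch only: swap nn.Linear layers for the vLLM-backed quant linears.
# `toy_block` and the chosen dtype are illustrative assumptions.
import torch
import torch.nn as nn

from llmc.compression.quantization.module_utils import (VllmQuantLinearFp8,
                                                         VllmQuantLinearInt8)

torch_dtype = torch.int8  # mirrors self.torch_dtype in BaseModel.build_model
quant_linear_cls = (VllmQuantLinearInt8 if torch_dtype == torch.int8
                    else VllmQuantLinearFp8)

toy_block = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))
for idx, module in enumerate(toy_block):
    if isinstance(module, nn.Linear):
        # .new() only allocates empty int8/fp8 weight and per-channel scale
        # buffers; the real quantized weights are loaded later from safetensors.
        toy_block[idx] = quant_linear_cls.new(module)

print(toy_block)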