Dev quant #421
Changes from all commits
@@ -8,7 +8,6 @@
from loguru import logger
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

from .quant import FloatQuantizer
from .utils import is_fp8_supported_gpu

if is_fp8_supported_gpu():
@@ -23,6 +22,10 @@
        'Using LLMC Quantizer implementation instead.'
    )

+ try:
+     from vllm import _custom_ops as ops
+ except ModuleNotFoundError:
+     ops = None

try:
    import fast_hadamard_transform
@@ -34,8 +37,6 @@
        'If you need it, please install it firstly.'
    )

from .utils import calculate_zeros_width


def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias):
    x, scale = act_quant(x, block_size)
@@ -127,13 +128,105 @@ def new(cls, module):
        return new_module


class VllmQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))

        if bias:
            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
        else:
            self.register_buffer('bias', None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale, _ \
            = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )
        return output_tensor.unsqueeze(0)

    @classmethod
    @torch.no_grad()
    def new(cls, module):
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        new_module = cls(in_features, out_features, bias)
        return new_module
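For reference, a minimal sketch of the per-output-channel symmetric int8 weight quantization implied by the buffer shapes above ((out, in) int8 weights with an (out, 1) float32 scale). The helper name quantize_weight_int8 and the dequantized reference matmul are illustrative assumptions, not the PR's code; the real path (ops.scaled_int8_quant plus cutlass_scaled_mm) needs a CUDA GPU with vLLM installed.

import torch

# Hypothetical illustration: per-output-channel symmetric int8 weight quantization
# that would populate the (out, in) int8 `weight` and (out, 1) fp32 `weight_scale`
# buffers of VllmQuantLinearInt8. The fused kernel is emulated here by dequantizing.
def quantize_weight_int8(w_fp16: torch.Tensor):
    # w_fp16: (out_features, in_features)
    scale = w_fp16.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.clamp(torch.round(w_fp16 / scale), -128, 127).to(torch.int8)
    return w_int8, scale.float()

w = torch.randn(64, 32, dtype=torch.float16)
w_int8, w_scale = quantize_weight_int8(w)

x = torch.randn(4, 32, dtype=torch.float16)
# Reference of what the fused GEMM computes: x @ (w_q * w_scale).T
y_ref = x.float() @ (w_int8.float() * w_scale).t()
print(y_ref.shape)  # torch.Size([4, 64])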
class VllmQuantLinearFp8(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))  # noqa
        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
        else:
            self.register_buffer('bias', None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale \
            = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )

        return output_tensor.unsqueeze(0)

    @classmethod
    @torch.no_grad()
    def new(cls, module):
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        new_module = cls(in_features, out_features, bias)
        return new_module
Comment on lines +178 to +220

The forward method here assumes a batch size of 1 because of input_tensor.squeeze(0) and output_tensor.unsqueeze(0), which breaks for batches larger than 1. Reshape the input tensor to handle arbitrary batch sizes:

def forward(self, input_tensor):
    input_shape = input_tensor.shape
    input_tensor = input_tensor.view(-1, self.in_features)
    output_tensor = torch.empty(
        (input_tensor.shape[0], self.out_features),
        dtype=input_tensor.dtype,
        device=input_tensor.device,
        requires_grad=False,
    )
    input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
    torch.ops._C.cutlass_scaled_mm(
        output_tensor,
        input_tensor_quant,
        self.weight.t(),
        input_tensor_scale,
        self.weight_scale.float(),
        self.bias,
    )
    return output_tensor.view(*input_shape[:-1], self.out_features)
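For context on what act_quant_func delegates to vLLM, here is a conceptual, pure-PyTorch reference for dynamic per-token FP8 activation quantization, similar in spirit to ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=True). The helper name and the exact scale formula are assumptions for illustration, not vLLM's kernel.

import torch

# Conceptual reference (assumption, not vLLM's kernel): each token (row) gets its
# own scale so that its max magnitude maps onto the float8_e4m3fn range.
def per_token_fp8_quant(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (x / scale).to(torch.float8_e4m3fn)
    return x_q, scale.float()

x = torch.randn(4, 32, dtype=torch.bfloat16)
x_q, x_scale = per_token_fp8_quant(x)
# Dequantized reference; cutlass_scaled_mm fuses this rescaling into the GEMM.
x_deq = x_q.to(torch.float32) * x_scale
print(x_q.dtype, x_scale.shape)  # torch.float8_e4m3fn torch.Size([4, 1])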
class LlmcFp8Linear(nn.Module):
    def __init__(self, in_features, out_features, bias, block_size):
        super().__init__()
        self.block_size = block_size
        self.in_features = in_features
        self.out_features = out_features
-       if bias is not None:
+       if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
@@ -172,7 +265,7 @@ def forward(self, x):
    def new(cls, module, block_size):
        in_features = module.in_features
        out_features = module.out_features
-       bias = module.bias
+       bias = module.bias is not None
        new_module = cls(in_features, out_features, bias, block_size)
        return new_module
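The bias change above pairs with the `if bias:` check in __init__; a small illustrative snippet (not part of the PR) shows why the constructor now expects a boolean flag rather than the bias tensor itself.

import torch
import torch.nn as nn

# Illustrative only: truth-testing a multi-element bias tensor is ambiguous in
# PyTorch, so `if bias:` needs a bool, not the tensor.
linear = nn.Linear(8, 4)
try:
    if linear.bias:               # tensor with 4 elements -> ambiguous truth value
        pass
except RuntimeError as e:
    print('truth-testing the tensor fails:', e)

has_bias = linear.bias is not None    # what the updated new() passes
print('has_bias =', has_bias)         # True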
The second file changed in this PR:
@@ -4,7 +4,6 @@
import os
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from functools import partial

import torch
import torch.nn as nn
@@ -16,7 +15,8 @@

from llmc.compression.quantization.module_utils import (
    _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
-   _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)
+   _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
+   VllmQuantLinearInt8)


class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,10 @@ def __init__(self, config, device_map=None, use_cache=False):
        self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
        self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
        torch_dtype = self.config.model.torch_dtype
-       self.torch_dtype = torch_dtype if torch_dtype == 'auto' else eval(torch_dtype)
+       self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
+       self.block_wise_quant = self.config.model.get('block_wise_quant', False)
+       if self.block_wise_quant:
+           assert self.torch_dtype == torch.float8_e4m3fn
        self.device_map = device_map
        self.use_cache = use_cache
        self.mm_model = None
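As a quick aside, torch_dtype arrives from the config as a string and is resolved with eval; the config values below are hypothetical, but they show how the two branches behave and what the block_wise_quant assert relies on.

import torch

# Hypothetical config values: 'auto' is passed through, anything else is eval'd
# into a torch dtype object.
for torch_dtype in ['auto', 'torch.float8_e4m3fn', 'torch.int8']:
    resolved = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
    print(torch_dtype, '->', resolved)
# auto -> auto
# torch.float8_e4m3fn -> torch.float8_e4m3fn
# torch.int8 -> torch.int8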
@@ -199,20 +202,32 @@ def build_model(self):
        if hasattr(self.model_config, 'use_cache'):
            self.model_config.use_cache = False
        logger.info(f'self.model_config : {self.model_config}')
-       if self.torch_dtype == torch.float8_e4m3fn:
+       if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
            with init_empty_weights():
                self.model = AutoModelForCausalLM.from_config(config=self.model_config,
                                                              torch_dtype=torch.float16,
                                                              trust_remote_code=True)
            self.find_blocks()
-           self.fp8_block_size \
-               = self.model_config.quantization_config['weight_block_size'][0]
+           if self.torch_dtype == torch.float8_e4m3fn:
+               if self.block_wise_quant:
+                   self.fp8_block_size \
+                       = self.model_config.quantization_config['weight_block_size'][0]
+                   params_dict = {'block_size': self.fp8_block_size}
+                   quant_linear_cls = LlmcFp8Linear
+               else:
+                   params_dict = {}
+                   quant_linear_cls = VllmQuantLinearFp8
+           elif self.torch_dtype == torch.int8:
+               params_dict = {}
+               quant_linear_cls = VllmQuantLinearInt8

            for block_idx, block in enumerate(self.blocks):
-               self.replace_module_block(LlmcFp8Linear,
+               self.replace_module_block(quant_linear_cls,
                                          block,
                                          block_idx,
-                                         {'block_size': self.fp8_block_size})
-           self.load_fp8_weight()
+                                         params_dict)
+
+           self.load_quant_weight()

            logger.info(f'fp8 block size: {self.fp8_block_size}')
        else:
@@ -226,7 +241,7 @@ def build_model(self):
            )
        logger.info(f'self.model : {self.model}')

-   def load_fp8_weight(self):
+   def load_quant_weight(self):
        state_dict = self.model.state_dict()
        model_index_file = os.path.join(self.model_path, 'model.safetensors.index.json')
@@ -241,7 +256,7 @@ def load_fp8_weight(self):

        for shard_path, tensor_names in shard_to_tensors.items():
            full_shard_path = os.path.join(self.model_path, shard_path)
-           logger.info(f'Loading FP8 shard: {full_shard_path}')
+           logger.info(f'Loading Quant shard: {full_shard_path}')
            with safe_open(full_shard_path, framework='pt', device='cpu') as f:
                for weight_name in tensor_names:
                    tensor = f.get_tensor(weight_name)
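As a reference for what load_quant_weight does with the shard index, here is a minimal, self-contained sketch of reading model.safetensors.index.json, grouping tensors by shard, and copying them out with safe_open. The variable names, the placeholder path, and the grouping loop are illustrative, not the PR's exact code.

import json
import os
from collections import defaultdict

from safetensors import safe_open

model_path = '/path/to/model'   # placeholder path
index_file = os.path.join(model_path, 'model.safetensors.index.json')
with open(index_file) as f:
    weight_map = json.load(f)['weight_map']   # {tensor_name: shard_file}

# Group tensor names by the shard file that stores them, then open each shard
# once and pull out its tensors on CPU.
shard_to_tensors = defaultdict(list)
for tensor_name, shard_file in weight_map.items():
    shard_to_tensors[shard_file].append(tensor_name)

loaded = {}
for shard_file, tensor_names in shard_to_tensors.items():
    shard_path = os.path.join(model_path, shard_file)
    with safe_open(shard_path, framework='pt', device='cpu') as f:
        for name in tensor_names:
            loaded[name] = f.get_tensor(name)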
The forward method in VllmQuantLinearInt8 assumes a batch size of 1 due to the use of input_tensor.squeeze(0) and output_tensor.unsqueeze(0). This will cause issues when processing batches of data with a size greater than 1. Reshape the input tensor to handle arbitrary batch sizes.
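The reshape pattern from the suggested change above (flatten to (-1, self.in_features) before quantization, then view the output back to (*input_shape[:-1], self.out_features)) applies unchanged to VllmQuantLinearInt8.forward, since both classes share the same forward body.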