Merged
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/awq_w_only_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
calib:
name: pileval
download: False
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/osplus_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
calib:
name: pileval
download: False
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/quarot_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
quant:
method: Quarot
weight:
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/rtn_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
quant:
method: RTN
weight:
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/rtn_w_only_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
quant:
method: RTN
weight:
1 change: 1 addition & 0 deletions configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml
@@ -5,6 +5,7 @@ model:
path: Deepseekv3-fp8-path
tokenizer_mode: fast
torch_dtype: torch.float8_e4m3fn
block_wise_quant: True
calib:
name: pileval
download: False
103 changes: 98 additions & 5 deletions llmc/compression/quantization/module_utils.py
@@ -8,7 +8,6 @@
from loguru import logger
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

from .quant import FloatQuantizer
from .utils import is_fp8_supported_gpu

if is_fp8_supported_gpu():
@@ -23,6 +22,10 @@
'Using LLMC Quantizer implementation instead.'
)

try:
from vllm import _custom_ops as ops
except ModuleNotFoundError:
ops = None

try:
import fast_hadamard_transform
@@ -34,8 +37,6 @@
'If you need it, please install it firstly.'
)

from .utils import calculate_zeros_width


def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias):
x, scale = act_quant(x, block_size)
@@ -127,13 +128,105 @@ def new(cls, module):
return new_module


class VllmQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))

if bias:
self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
else:
self.register_buffer('bias', None)

def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale, _ \
= ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale

def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)

@classmethod
@torch.no_grad()
def new(cls, module):
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
new_module = cls(in_features, out_features, bias)
return new_module
Comment on lines +131 to +175


critical

The forward method in VllmQuantLinearInt8 assumes a batch size of 1 due to the use of input_tensor.squeeze(0) and output_tensor.unsqueeze(0). This will cause issues when processing batches of data with a size greater than 1. Reshape the input tensor to handle arbitrary batch sizes.

    def forward(self, input_tensor):
        input_shape = input_tensor.shape
        input_tensor = input_tensor.view(-1, self.in_features)

        output_tensor = torch.empty(
            (input_tensor.shape[0], self.out_features),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )
        return output_tensor.view(*input_shape[:-1], self.out_features)
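
As a quick, framework-independent check of the reshape pattern in this suggestion, the sketch below stubs out the vLLM kernels (ops.scaled_int8_quant and torch.ops._C.cutlass_scaled_mm) with a plain float matmul, so it only verifies that arbitrary leading batch dimensions are flattened and restored correctly; it is not an int8 test.

    import torch

    def check_batched_reshape(in_features=16, out_features=8):
        weight = torch.randn(out_features, in_features)

        def forward(input_tensor):
            input_shape = input_tensor.shape
            x = input_tensor.view(-1, in_features)            # flatten all leading dims
            out = x @ weight.t()                              # stand-in for the quant + matmul path
            return out.view(*input_shape[:-1], out_features)  # restore leading dims

        for shape in [(1, 5, in_features), (4, 7, in_features), (3, in_features)]:
            out = forward(torch.randn(shape))
            assert out.shape == (*shape[:-1], out_features)

    check_batched_reshape()

Using view rather than squeeze(0)/unsqueeze(0) also keeps the layer shape-agnostic: a 2-D input passes through unchanged, and any number of leading dimensions is supported without special-casing batch size 1.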



class VllmQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) # noqa
self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
else:
self.register_buffer('bias', None)

def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale \
= ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale

def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale.float(),
self.bias,
)

return output_tensor.unsqueeze(0)

@classmethod
@torch.no_grad()
def new(cls, module):
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
new_module = cls(in_features, out_features, bias)
return new_module
Comment on lines +178 to +220


critical

The forward method in VllmQuantLinearFp8 also assumes a batch size of 1 due to the use of input_tensor.squeeze(0) and output_tensor.unsqueeze(0). This needs to be corrected to properly support batched inputs by reshaping the input tensor.

    def forward(self, input_tensor):
        input_shape = input_tensor.shape
        input_tensor = input_tensor.view(-1, self.in_features)

        output_tensor = torch.empty(
            (input_tensor.shape[0], self.out_features),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )
        return output_tensor.view(*input_shape[:-1], self.out_features)



class LlmcFp8Linear(nn.Module):
def __init__(self, in_features, out_features, bias, block_size):
super().__init__()
self.block_size = block_size
self.in_features = in_features
self.out_features = out_features
if bias is not None:
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter('bias', None)
@@ -172,7 +265,7 @@ def forward(self, x):
def new(cls, module, block_size):
in_features = module.in_features
out_features = module.out_features
bias = module.bias
bias = module.bias is not None
new_module = cls(in_features, out_features, bias, block_size)
return new_module

15 changes: 0 additions & 15 deletions llmc/compression/quantization/utils.py
@@ -5,21 +5,6 @@ def make_divisible(c, divisor):
return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError

base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width


def is_fp8_supported_gpu():
if not torch.cuda.is_available():
return False
37 changes: 26 additions & 11 deletions llmc/models/base_model.py
@@ -4,7 +4,6 @@
import os
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from functools import partial

import torch
import torch.nn as nn
@@ -16,7 +15,8 @@

from llmc.compression.quantization.module_utils import (
_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
_TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)
_TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
VllmQuantLinearInt8)


class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,10 @@ def __init__(self, config, device_map=None, use_cache=False):
self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False) # noqa
torch_dtype = self.config.model.torch_dtype
self.torch_dtype = torch_dtype if torch_dtype == 'auto' else eval(torch_dtype)
self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
self.block_wise_quant = self.config.model.get('block_wise_quant', False)
if self.block_wise_quant:
assert self.torch_dtype == torch.float8_e4m3fn
self.device_map = device_map
self.use_cache = use_cache
self.mm_model = None
@@ -199,20 +202,32 @@ def build_model(self):
if hasattr(self.model_config, 'use_cache'):
self.model_config.use_cache = False
logger.info(f'self.model_config : {self.model_config}')
if self.torch_dtype == torch.float8_e4m3fn:
if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(config=self.model_config,
torch_dtype=torch.float16,
trust_remote_code=True)
self.find_blocks()
self.fp8_block_size \
= self.model_config.quantization_config['weight_block_size'][0]
if self.torch_dtype == torch.float8_e4m3fn:
if self.block_wise_quant:
self.fp8_block_size \
= self.model_config.quantization_config['weight_block_size'][0]
params_dict = {'block_size': self.fp8_block_size}
quant_linear_cls = LlmcFp8Linear
else:
params_dict = {}
quant_linear_cls = VllmQuantLinearFp8
elif self.torch_dtype == torch.int8:
params_dict = {}
quant_linear_cls = VllmQuantLinearInt8

for block_idx, block in enumerate(self.blocks):
self.replace_module_block(LlmcFp8Linear,
self.replace_module_block(quant_linear_cls,
block,
block_idx,
{'block_size': self.fp8_block_size})
self.load_fp8_weight()
params_dict)

self.load_quant_weight()

logger.info(f'fp8 block size: {self.fp8_block_size}')


high

The logger.info call for fp8 block size can lead to an AttributeError if self.torch_dtype is not torch.float8_e4m3fn or if self.block_wise_quant is false, because self.fp8_block_size is only defined within that conditional block. Move the log statement inside the if self.block_wise_quant: block.

Suggested change
logger.info(f'fp8 block size: {self.fp8_block_size}')
# logger.info(f'fp8 block size: {self.fp8_block_size}')
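
If the log line is worth keeping rather than commenting out, a minimal sketch of the reviewer's alternative — moving the statement into the block-wise branch where self.fp8_block_size is assigned (attribute and variable names taken from the diff above; surrounding indentation abbreviated):

    if self.block_wise_quant:
        self.fp8_block_size = \
            self.model_config.quantization_config['weight_block_size'][0]
        params_dict = {'block_size': self.fp8_block_size}
        quant_linear_cls = LlmcFp8Linear
        # fp8_block_size is guaranteed to exist here
        logger.info(f'fp8 block size: {self.fp8_block_size}')
    else:
        params_dict = {}
        quant_linear_cls = VllmQuantLinearFp8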

else:
@@ -226,7 +241,7 @@ def build_model(self):
)
logger.info(f'self.model : {self.model}')

def load_fp8_weight(self):
def load_quant_weight(self):
state_dict = self.model.state_dict()
model_index_file = os.path.join(self.model_path, 'model.safetensors.index.json')

@@ -241,7 +256,7 @@

for shard_path, tensor_names in shard_to_tensors.items():
full_shard_path = os.path.join(self.model_path, shard_path)
logger.info(f'Loading FP8 shard: {full_shard_path}')
logger.info(f'Loading Quant shard: {full_shard_path}')
with safe_open(full_shard_path, framework='pt', device='cpu') as f:
for weight_name in tensor_names:
tensor = f.get_tensor(weight_name)