Commit 35ca4db

Support load from quantized weights (Dev quant)
2 parents 8edf37b + 2d96f20

File tree

9 files changed: +130 -31 lines changed

configs/quantization/deepseekv3/awq_w_only_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False
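
Every DeepSeek-V3 config touched by this commit gets the same one-line change: a block_wise_quant: True flag next to the FP8 torch_dtype. A minimal sketch of how the flag is consumed, mirroring the base_model.py change further down (the dict literal is a hypothetical stand-in for the parsed model: section of the YAML):

# Minimal sketch: how the new flag is read (mirrors the base_model.py change below).
# The dict literal is a hypothetical stand-in for the parsed `model:` YAML section.
import torch

model_cfg = {
    'path': 'Deepseekv3-fp8-path',
    'tokenizer_mode': 'fast',
    'torch_dtype': 'torch.float8_e4m3fn',
    'block_wise_quant': True,
}

torch_dtype = eval(model_cfg['torch_dtype'])                 # as in BaseModel.__init__
block_wise_quant = model_cfg.get('block_wise_quant', False)  # defaults to False
if block_wise_quant:
    # block-wise FP8 loading is only meaningful for an FP8 checkpoint
    assert torch_dtype == torch.float8_e4m3fn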

configs/quantization/deepseekv3/osplus_w_a_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False

configs/quantization/deepseekv3/quarot_w_a_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: Quarot
     weight:

configs/quantization/deepseekv3/rtn_w_a_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: RTN
     weight:

configs/quantization/deepseekv3/rtn_w_only_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 quant:
     method: RTN
     weight:

configs/quantization/deepseekv3/smoothquant_w_a_dsv3.yml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
     path: Deepseekv3-fp8-path
     tokenizer_mode: fast
     torch_dtype: torch.float8_e4m3fn
+    block_wise_quant: True
 calib:
     name: pileval
     download: False

llmc/compression/quantization/module_utils.py

Lines changed: 98 additions & 5 deletions

@@ -8,7 +8,6 @@
 from loguru import logger
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

-from .quant import FloatQuantizer
 from .utils import is_fp8_supported_gpu

 if is_fp8_supported_gpu():
@@ -23,6 +22,10 @@
         'Using LLMC Quantizer implementation instead.'
     )

+try:
+    from vllm import _custom_ops as ops
+except ModuleNotFoundError:
+    ops = None

 try:
     import fast_hadamard_transform
@@ -34,8 +37,6 @@
         'If you need it, please install it firstly.'
     )

-from .utils import calculate_zeros_width
-

 def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias):
     x, scale = act_quant(x, block_size)
@@ -127,13 +128,105 @@ def new(cls, module):
         return new_module


+class VllmQuantLinearInt8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.int8))
+        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
+
+        if bias:
+            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
+        else:
+            self.register_buffer('bias', None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale, _ \
+            = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(
+            output_tensor,
+            input_tensor_quant,
+            self.weight.t(),
+            input_tensor_scale,
+            self.weight_scale.float(),
+            self.bias,
+        )
+        return output_tensor.unsqueeze(0)
+
+    @classmethod
+    @torch.no_grad()
+    def new(cls, module):
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        new_module = cls(in_features, out_features, bias)
+        return new_module
+
+
+class VllmQuantLinearFp8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))  # noqa
+        self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16))
+        else:
+            self.register_buffer('bias', None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale \
+            = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(
+            output_tensor,
+            input_tensor_quant,
+            self.weight.t(),
+            input_tensor_scale,
+            self.weight_scale.float(),
+            self.bias,
+        )
+
+        return output_tensor.unsqueeze(0)
+
+    @classmethod
+    @torch.no_grad()
+    def new(cls, module):
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        new_module = cls(in_features, out_features, bias)
+        return new_module
+
+
 class LlmcFp8Linear(nn.Module):
     def __init__(self, in_features, out_features, bias, block_size):
         super().__init__()
         self.block_size = block_size
         self.in_features = in_features
         self.out_features = out_features
-        if bias is not None:
+        if bias:
             self.bias = nn.Parameter(torch.empty(out_features))
         else:
             self.register_parameter('bias', None)
@@ -172,7 +265,7 @@ def forward(self, x):
     def new(cls, module, block_size):
         in_features = module.in_features
         out_features = module.out_features
-        bias = module.bias
+        bias = module.bias is not None
         new_module = cls(in_features, out_features, bias, block_size)
         return new_module
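For reference, a plain-PyTorch illustration of what the vLLM custom ops used above are needed for. This is an assumed, simplified semantics sketch, not vLLM's actual kernels: scaled_int8_quant(..., symmetric=True) corresponds to dynamic per-token symmetric INT8 quantization, and cutlass_scaled_mm to a matmul over the quantized tensors, rescaled by the per-token activation scales and per-output-channel weight scales.

# Illustrative reference only (assumed semantics, not vLLM's actual kernels).
import torch

def per_token_int8_quant_ref(x: torch.Tensor):
    # Dynamic symmetric quantization: one fp32 scale per row (token or output channel).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.to(torch.float32)

def scaled_mm_ref(a_q, a_scale, w_q, w_scale, bias=None):
    # What cutlass_scaled_mm is used for here:
    # (a_q * a_scale) @ (w_q * w_scale).T + bias,
    # with a_scale per token [M, 1] and w_scale per output channel [N, 1].
    out = (a_q.float() * a_scale) @ (w_q.float() * w_scale).t()
    if bias is not None:
        out = out + bias.float()
    return out

x = torch.randn(4, 16)
w = torch.randn(8, 16)
w_q, w_scale = per_token_int8_quant_ref(w)   # per-output-channel weight scales
a_q, a_scale = per_token_int8_quant_ref(x)   # per-token activation scales
y = scaled_mm_ref(a_q, a_scale, w_q, w_scale)
print(y.shape)  # torch.Size([4, 8])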

llmc/compression/quantization/utils.py

Lines changed: 0 additions & 15 deletions

@@ -5,21 +5,6 @@ def make_divisible(c, divisor):
     return (c + divisor - 1) // divisor


-def calculate_zeros_width(in_features, group_size=128, pack_num=8):
-    if group_size >= 128:
-        size_multiplier = 1
-    elif group_size == 64:
-        size_multiplier = 2
-    elif group_size == 32:
-        size_multiplier = 4
-    else:
-        raise NotImplementedError
-
-    base_width = make_divisible(in_features // group_size, pack_num)
-    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
-    return base_width
-
-
 def is_fp8_supported_gpu():
     if not torch.cuda.is_available():
         return False

llmc/models/base_model.py

Lines changed: 26 additions & 11 deletions

@@ -4,7 +4,6 @@
 import os
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
-from functools import partial

 import torch
 import torch.nn as nn
@@ -16,7 +15,8 @@

 from llmc.compression.quantization.module_utils import (
     _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
-    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)
+    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
+    VllmQuantLinearInt8)


 class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,10 @@ def __init__(self, config, device_map=None, use_cache=False):
         self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
         self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
         torch_dtype = self.config.model.torch_dtype
-        self.torch_dtype = torch_dtype if torch_dtype == 'auto' else eval(torch_dtype)
+        self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
+        self.block_wise_quant = self.config.model.get('block_wise_quant', False)
+        if self.block_wise_quant:
+            assert self.torch_dtype == torch.float8_e4m3fn
         self.device_map = device_map
         self.use_cache = use_cache
         self.mm_model = None
@@ -199,20 +202,32 @@ def build_model(self):
         if hasattr(self.model_config, 'use_cache'):
             self.model_config.use_cache = False
         logger.info(f'self.model_config : {self.model_config}')
-        if self.torch_dtype == torch.float8_e4m3fn:
+        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(config=self.model_config,
                                                               torch_dtype=torch.float16,
                                                               trust_remote_code=True)
             self.find_blocks()
-            self.fp8_block_size \
-                = self.model_config.quantization_config['weight_block_size'][0]
+            if self.torch_dtype == torch.float8_e4m3fn:
+                if self.block_wise_quant:
+                    self.fp8_block_size \
+                        = self.model_config.quantization_config['weight_block_size'][0]
+                    params_dict = {'block_size': self.fp8_block_size}
+                    quant_linear_cls = LlmcFp8Linear
+                else:
+                    params_dict = {}
+                    quant_linear_cls = VllmQuantLinearFp8
+            elif self.torch_dtype == torch.int8:
+                params_dict = {}
+                quant_linear_cls = VllmQuantLinearInt8
+
             for block_idx, block in enumerate(self.blocks):
-                self.replace_module_block(LlmcFp8Linear,
+                self.replace_module_block(quant_linear_cls,
                                           block,
                                           block_idx,
-                                          {'block_size': self.fp8_block_size})
-            self.load_fp8_weight()
+                                          params_dict)
+
+            self.load_quant_weight()

             logger.info(f'fp8 block size: {self.fp8_block_size}')
         else:
@@ -226,7 +241,7 @@ def build_model(self):
             )
         logger.info(f'self.model : {self.model}')

-    def load_fp8_weight(self):
+    def load_quant_weight(self):
         state_dict = self.model.state_dict()
         model_index_file = os.path.join(self.model_path, 'model.safetensors.index.json')

@@ -241,7 +256,7 @@ def load_fp8_weight(self):

         for shard_path, tensor_names in shard_to_tensors.items():
             full_shard_path = os.path.join(self.model_path, shard_path)
-            logger.info(f'Loading FP8 shard: {full_shard_path}')
+            logger.info(f'Loading Quant shard: {full_shard_path}')
             with safe_open(full_shard_path, framework='pt', device='cpu') as f:
                 for weight_name in tensor_names:
                     tensor = f.get_tensor(weight_name)
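
A hedged sketch of the shard-loading pattern load_quant_weight() follows: group tensors by shard using model.safetensors.index.json, open each shard with safe_open on CPU, and read every tensor by name. The checkpoint path below is hypothetical, and the final copy into the meta-initialised modules is elided.

# Sketch only: mirrors the grouping/loading pattern visible in load_quant_weight().
import json
import os
from collections import defaultdict

from safetensors import safe_open

model_path = '/path/to/Deepseekv3-fp8-path'  # hypothetical checkpoint directory
with open(os.path.join(model_path, 'model.safetensors.index.json')) as f:
    weight_map = json.load(f)['weight_map']  # tensor name -> shard file name

shard_to_tensors = defaultdict(list)
for tensor_name, shard_file in weight_map.items():
    shard_to_tensors[shard_file].append(tensor_name)

for shard_file, tensor_names in shard_to_tensors.items():
    full_shard_path = os.path.join(model_path, shard_file)
    with safe_open(full_shard_path, framework='pt', device='cpu') as f:
        for weight_name in tensor_names:
            tensor = f.get_tensor(weight_name)
            # ...copy `tensor` into the matching parameter/buffer of the model...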
