Skip to content

Commit 6adea79

Browse files
authored
fix nvfp4 weight_scale2 (Tencent#76)
1 parent 23f031e commit 6adea79

File tree

5 files changed

+88
-3
lines changed

5 files changed

+88
-3
lines changed

angelslim/compressor/quant/modules/nvfp4/nvfp4.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import torch
1716

1817
from .....utils import print_info
@@ -89,7 +88,12 @@ def get_weights_scaling_factor(
8988
return q_per_block_scale
9089

9190
def post_process(self, sub_layer, name):
92-
weight_observer_amax = self.model.weight_scales_dict[name]
91+
# TODO: Fuse observer amax because TRT-LLM requires the qkv,
92+
# gate and up to share the weight_scale2
93+
weight_observer_amax, input_observer_amax = self.model.fuse_observer_amax(
94+
sub_layer, name
95+
)
96+
9397
weight_scale_2 = self.get_weights_scaling_factor_2(weight_observer_amax)
9498
self.model.weight_scales_dict_2[name] = weight_scale_2
9599

@@ -100,6 +104,5 @@ def post_process(self, sub_layer, name):
100104
)
101105
self.model.weight_scales_dict[name] = weight_scale
102106

103-
input_observer_amax = self.model.act_scales_dict[name]
104107
input_scale = self.get_activation_scaling_factor(input_observer_amax)
105108
self.model.act_scales_dict[name] = input_scale

angelslim/compressor/quant/ptq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ def _convert(self):
218218
self.ptq_hook.post_process()
219219

220220
quant_convert_module = self.quant_model.get_quant_convert_module()
221+
if "nvfp4" in self.quant_algo:
222+
self.quant_model.get_observer_values()
221223
# 2. insert qdq module
222224
for name, sub_layer in self.ptq_hook.quant_layers_dict.items():
223225
parent_layer, sub_name = find_parent_layer_and_sub_name(

angelslim/models/base_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import copy
1516
import re
1617
from abc import ABCMeta, abstractmethod
1718
from typing import Optional
@@ -170,6 +171,10 @@ def get_nvfp4_qdq_module(self, sub_layer, name):
170171
raise NotImplementedError
171172
return q_linear
172173

174+
def get_observer_values(self):
    """Snapshot the raw observer amax values before they are overwritten.

    ``weight_scales_dict`` and ``act_scales_dict`` initially hold the
    observer amax values collected during calibration; the nvfp4
    post-processing step later replaces those entries with computed
    scaling factors. Deep copies are kept here so the original amax
    values remain available for cross-layer fusion afterwards.
    """
    # NOTE(review): assumes this runs before post_process mutates the
    # scale dicts — confirm the call site ordering in _convert.
    snapshot = copy.deepcopy
    self.input_observer_amax_dict = snapshot(self.act_scales_dict)
    self.weight_observer_amax_dict = snapshot(self.weight_scales_dict)
177+
173178
def get_kvcache_observer_layers_names(self, observe_names):
174179
names = ["self_attn.k_proj", "self_attn.v_proj"]
175180
return [

angelslim/models/llm/qwen.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,43 @@ def get_save_func(self):
9494
raise NotImplementedError(
9595
f"deploy_backend {self.deploy_backend} is not supported for saving."
9696
)
97+
98+
def fuse_observer_amax(self, sub_layer, name):
    """Return fused (weight, activation) observer amax values for ``name``.

    TRT-LLM requires the q/k/v projections to share one weight_scale2,
    and likewise gate/up; the fused amax for a layer in such a group is
    the max over the group members' recorded observer amax values.
    Layers outside a fused group simply use their own recorded values.

    Args:
        sub_layer: the quantized layer module (unused; kept for
            interface parity with the post-process caller).
        name: dotted module name, e.g.
            ``"model.layers.0.self_attn.q_proj"``.

    Returns:
        Tuple ``(weight_observer_amax, input_observer_amax)``.

    Raises:
        KeyError: if a member of the fused group is missing from the
            observer dicts (e.g. it was skipped during calibration).
    """
    # Determine which sibling projections must share scaling factors.
    qkv = ("q_proj", "k_proj", "v_proj")
    gate_up = ("gate_proj", "up_proj")
    if any(proj in name for proj in qkv):
        group = qkv
    elif any(proj in name for proj in gate_up):
        group = gate_up
    else:
        # Not part of a fused GEMM: use this layer's own amax values.
        return (
            self.weight_observer_amax_dict[name],
            self.input_observer_amax_dict[name],
        )

    prefix = name.rsplit(".", 1)[0]
    keys = [f"{prefix}.{proj}" for proj in group]
    # max() over the group's amax values, matching the original
    # element-wise comparison semantics for scalar amax entries.
    weight_observer_amax = max(self.weight_observer_amax_dict[k] for k in keys)
    input_observer_amax = max(self.input_observer_amax_dict[k] for k in keys)
    return weight_observer_amax, input_observer_amax
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Global configuration of pipeline
2+
global:
3+
save_path: ./output
4+
5+
# Simplified Configuration for LLM compression
6+
model:
7+
name: Qwen
8+
model_path: Qwen/Qwen3-235B-A22B
9+
trust_remote_code: true
10+
low_cpu_mem_usage: true
11+
use_cache: false
12+
torch_dtype: auto
13+
device_map: auto
14+
15+
# Compression configuration
16+
compression:
17+
name: PTQ
18+
quantization:
19+
name: nvfp4
20+
bits: 4
21+
quant_method:
22+
weight: "per-block"
23+
activation: "per-block"
24+
group_size: 16
25+
ignore_layers: # Skip quantization for these layers
26+
- "lm_head"
27+
- "model.embed_tokens"
28+
29+
# Dataset for calibration
30+
dataset:
31+
name: TextDataset
32+
data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl
33+
max_seq_length: 4096
34+
num_samples: 256
35+
batch_size: 1

0 commit comments

Comments
 (0)