
Commit 692d8f2

[TRTLLM-9455][feat] support for new checkpoint (#10082)
Signed-off-by: binghanc <176802681+binghanc@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 3e0344a commit 692d8f2

File tree

1 file changed: +156 -16 lines

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 156 additions & 16 deletions
@@ -37,6 +37,7 @@
 from tqdm import tqdm
 from transformers import PretrainedConfig
 
+import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 from tensorrt_llm._ipc_utils import can_access_peer
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
@@ -142,6 +143,44 @@ def __init__(self, model, is_draft_model: bool = False):
 
     def load_weights(self, weights: Dict, skip_modules: List[str] = []):
 
+        def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2,
+                                             new_scale_2, device):
+            """
+            Dequantize FP4 weights and requantize them with a new global scale.
+
+            Args:
+                weight: FP4 quantized weight tensor, 2D, packed 2 e2m1 values per byte along the last dim
+                weight_scale: FP8 per-block scaling factors
+                old_scale_2: original global scale (amax / (448 * 6))
+                new_scale_2: new global scale (amax / (448 * 6))
+                device: target device for computation
+
+            Returns:
+                (requantized_weight, new_weight_scale)
+            """
+            # Remember the original dtype and shape of weight_scale
+            original_scale_dtype = weight_scale.dtype
+            original_scale_shape = weight_scale.shape
+
+            # Dequantize
+            dequant_shape = (weight.shape[0], weight.shape[1] * 2)
+            weight_dequant = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2(
+                weight.contiguous(),
+                weight_scale.flatten().view(
+                    fp4_utils.float4_sf_dtype).contiguous(), old_scale_2, 16, 1,
+                True).to(dtype=torch.bfloat16).reshape(dequant_shape)
+
+            # Requantize using the new_scale_2
+            weight_requant, weight_scale_requant = torch.ops.trtllm.fp4_quantize(
+                weight_dequant.to(device),
+                1.0 / new_scale_2.to(device),
+                16,  # scaling_vector_size
+                False)
+
+            # Ensure the returned scale has the same dtype as the input scale
+            return weight_requant.cpu(), weight_scale_requant.reshape(
+                original_scale_shape).view(original_scale_dtype).cpu()
+
         def rename_moe_weight(weights: Dict, rename_rules: Dict):
             result = {}
             for key, value in weights.items():
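
The helper above round-trips an NVFP4 tensor through bfloat16 so it can be re-expressed under a different global scale. Below is a minimal pure-PyTorch sketch of the same two-level scaling math, with FP4 rounding replaced by a clamp so it runs without the TensorRT-LLM custom ops. The block size of 16 mirrors the commit's scaling_vector_size; the tensor shapes and the doubled scale are illustrative assumptions, not values from this commit.

# Pure-PyTorch emulation of the dequantize -> requantize round trip performed by
# requantize_weight_with_new_scale, without the TensorRT-LLM custom ops.
# "FP4" is emulated by clamping to the e2m1 range instead of rounding to the real
# e2m1 codebook, so this illustrates the scaling math only, not bit-exact behavior.
import torch

E2M1_MAX = 6.0   # largest magnitude representable in FP4 (e2m1)
BLOCK = 16       # scaling_vector_size used by the commit

def emulated_quant(w: torch.Tensor, scale_2: torch.Tensor):
    # Per-16-element block scales along the last dim, plus one global scale_2.
    blocks = w.float().reshape(w.shape[0], -1, BLOCK)
    s_block = (blocks.abs().amax(dim=-1, keepdim=True) / E2M1_MAX /
               scale_2).clamp_min(1e-12)
    q = (blocks / (s_block * scale_2)).clamp(-E2M1_MAX, E2M1_MAX)
    return q, s_block

def emulated_dequant(q: torch.Tensor, s_block: torch.Tensor,
                     scale_2: torch.Tensor) -> torch.Tensor:
    return (q * s_block * scale_2).reshape(q.shape[0], -1)

w = torch.randn(8, 64)                        # hypothetical bf16/fp32 reference weight
old_s2 = w.abs().max() / (448.0 * E2M1_MAX)   # modelopt convention: amax / (448 * 6)
new_s2 = old_s2 * 2.0                         # pretend the fused partner had a larger scale_2

q_old, s_old = emulated_quant(w, old_s2)
w_dequant = emulated_dequant(q_old, s_old, old_s2)   # step 1: dequantize with the old scale
q_new, s_new = emulated_quant(w_dequant, new_s2)     # step 2: requantize with the new scale
assert torch.allclose(w, emulated_dequant(q_new, s_new, new_s2), atol=1e-5)

In the real kernels the requantization is lossy (values are re-rounded onto the e2m1 grid under the larger global scale), which is why the commit only requantizes the tensor whose original weight_scale_2 was the smaller of the two.
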
@@ -355,27 +394,128 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                         ).view(*attn_module.v_b_proj_dequant.shape).to(
                             attn_module.v_b_proj_dequant.dtype))
                 elif names[-1] == "kv_a_proj_with_mqa":
-                    fused_a = weights[
-                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
-                    if not is_lite:
-                        q_a_proj = weights[
-                            f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
-                        fused_a = torch.cat([q_a_proj, fused_a], dim=0)
-
-                    if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights:
-                        fused_a_scale = weights[
-                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"]
+                    nvfp4_fused_a = self.model_config.get_quant_config(
+                    ).layer_quant_mode.has_nvfp4() and weights[
+                        f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"].dtype == fp4_utils.float4_e2m1x2 and weights[
+                            f"{'.'.join(names[:-1])}.q_a_proj.weight"].dtype == fp4_utils.float4_e2m1x2
+                    if nvfp4_fused_a:
+                        ########### input_scale
+                        kv_a_proj_with_mqa_input_scale = weights[
+                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.input_scale"]
+                        if not is_lite:
+                            q_a_proj_input_scale = weights[
+                                f"{'.'.join(names[:-1])}.q_a_proj.input_scale"]
+                            assert kv_a_proj_with_mqa_input_scale == q_a_proj_input_scale, "kv_a_proj_with_mqa.input_scale and q_a_proj.input_scale should be the same"
+                        # modelopt ckpt stores amax/(448*6), convert to (448*6)/amax
+                        shared_input_scale = kv_a_proj_with_mqa_input_scale
+                        module.input_scale.data.copy_(1.0 / shared_input_scale)
+                        E2M1_MAX = 6.0
+                        module.inv_input_scale.data.copy_(module.input_scale /
+                                                          E2M1_MAX)
+                        ########### weight_scale_2
+                        need_requant_kv_a_proj_with_mqa = False
+                        need_requant_q_a_proj = False
+                        kv_a_proj_with_mqa_scale_2 = weights[
+                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_2"]
+                        shared_weight_scale_2 = kv_a_proj_with_mqa_scale_2
+                        if not is_lite:
+                            q_a_proj_scale_2 = weights[
+                                f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_2"]
+                            if kv_a_proj_with_mqa_scale_2 < q_a_proj_scale_2:
+                                shared_weight_scale_2 = q_a_proj_scale_2
+                                need_requant_kv_a_proj_with_mqa = True
+                            elif q_a_proj_scale_2 < kv_a_proj_with_mqa_scale_2:
+                                need_requant_q_a_proj = True
+
+                        ########### alpha
+                        alpha = shared_input_scale.float(
+                        ) * shared_weight_scale_2.float()
+                        module.alpha.data.copy_(alpha)
+                        module.scalar_alpha = alpha.item()
+
+                        ########### weights
+                        kv_a_proj_with_mqa = weights[
+                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
+
+                        if not is_lite:
+                            q_a_proj = weights[
+                                f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
+
+                        ########### weight_scale
+                        kv_a_proj_with_mqa_scale = weights[
+                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale"][:]
+                        kv_a_proj_with_mqa_scale = torch.ops.trtllm.block_scale_interleave(
+                            kv_a_proj_with_mqa_scale.view(
+                                fp4_utils.float4_sf_dtype))
                         if not is_lite:
                             q_a_proj_scale = weights[
-                                f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:]
-                            fused_a_scale = torch.cat(
-                                [q_a_proj_scale, fused_a_scale], dim=0)
+                                f"{'.'.join(names[:-1])}.q_a_proj.weight_scale"][:]
+                            q_a_proj_scale = torch.ops.trtllm.block_scale_interleave(
+                                q_a_proj_scale.view(fp4_utils.float4_sf_dtype))
+
+                        ########### requantize
+                        if need_requant_kv_a_proj_with_mqa:
+                            # requant kv_a_proj_with_mqa
+                            kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale(
+                                kv_a_proj_with_mqa,
+                                kv_a_proj_with_mqa_scale,
+                                kv_a_proj_with_mqa_scale_2,
+                                shared_weight_scale_2,
+                                device=module.weight.device,
+                            )
+                        if need_requant_q_a_proj:
+                            # requant q_a_proj
+                            q_a_proj, q_a_proj_scale = requantize_weight_with_new_scale(
+                                q_a_proj,
+                                q_a_proj_scale,
+                                q_a_proj_scale_2,
+                                shared_weight_scale_2,
+                                device=module.weight.device)
+
+                        ########### fuse and load weights
+                        if not is_lite:
+                            fused_a = torch.cat([q_a_proj, kv_a_proj_with_mqa],
+                                                dim=0)
+                        else:
+                            fused_a = kv_a_proj_with_mqa
+
+                        # For DeepseekV32: kv_a_proj_with_mqa is oversized
+                        # to include indexer k weights, which is filled in post_load_weights.
+                        module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
 
+                        ########### fuse weight_scale
+                        if not is_lite:
+                            fused_a_scale = torch.cat(
+                                [q_a_proj_scale, kv_a_proj_with_mqa_scale],
+                                dim=0)
+                        else:
+                            fused_a_scale = kv_a_proj_with_mqa_scale
+                        # For DeepseekV32: kv_a_proj_with_mqa is oversized
+                        # to include indexer k weights, which is filled in post_load_weights.
                         module.weight_scale.data[0:fused_a_scale.
                                                  shape[0]].copy_(fused_a_scale)
-                    # For DeepseekV32: kv_a_proj_with_mqa is oversized
-                    # to include indexer k weights, which is filled in post_load_weights.
-                    module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
+                    else:
+                        fused_a = weights[
+                            f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
+                        if not is_lite:
+                            q_a_proj = weights[
+                                f"{'.'.join(names[:-1])}.q_a_proj.weight"][:]
+                            fused_a = torch.cat([q_a_proj, fused_a], dim=0)
+
+                        if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights:
+                            fused_a_scale = weights[
+                                f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"]
+                            if not is_lite:
+                                q_a_proj_scale = weights[
+                                    f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:]
+                                fused_a_scale = torch.cat(
+                                    [q_a_proj_scale, fused_a_scale], dim=0)
+
+                            module.weight_scale.data[
+                                0:fused_a_scale.shape[0]].copy_(fused_a_scale)
+                        # For DeepseekV32: kv_a_proj_with_mqa is oversized
+                        # to include indexer k weights, which is filled in post_load_weights.
+                        module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
                 elif names[-1] in params_map:
                     module_weights = []
                     for new_name in params_map[names[-1]]:
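
For readers skimming the diff, here is a condensed, framework-free restatement of the scale handling in the new nvfp4 branch: check that the two projections share one input_scale, invert it for the runtime, pick the larger weight_scale_2 as the shared global scale, flag the other tensor for requantization, and combine the two checkpoint scales into alpha. The ckpt dict, its values, and the "layer." prefix are hypothetical stand-ins for the real weights dict.

import torch

E2M1_MAX = 6.0

# Hypothetical checkpoint entries for one decoder layer (names mirror the diff,
# values are made up). modelopt stores each scale as amax / (448 * 6).
ckpt = {
    "layer.q_a_proj.input_scale": torch.tensor(0.02),
    "layer.kv_a_proj_with_mqa.input_scale": torch.tensor(0.02),
    "layer.q_a_proj.weight_scale_2": torch.tensor(1.5e-3),
    "layer.kv_a_proj_with_mqa.weight_scale_2": torch.tensor(2.0e-3),
}

# Both projections feed the same fused GEMM, so their activation scales must agree.
assert torch.equal(ckpt["layer.q_a_proj.input_scale"],
                   ckpt["layer.kv_a_proj_with_mqa.input_scale"])
shared_input_scale = ckpt["layer.kv_a_proj_with_mqa.input_scale"]
input_scale = 1.0 / shared_input_scale   # ckpt stores amax/(448*6); runtime wants (448*6)/amax
inv_input_scale = input_scale / E2M1_MAX

# The fused weight can carry only one global weight_scale_2: keep the larger one
# (it covers the larger amax) and requantize the weight that was exported with the
# smaller one.
s2_q = ckpt["layer.q_a_proj.weight_scale_2"]
s2_kv = ckpt["layer.kv_a_proj_with_mqa.weight_scale_2"]
shared_weight_scale_2 = torch.maximum(s2_q, s2_kv)
need_requant_q_a_proj = bool(s2_q < shared_weight_scale_2)
need_requant_kv_a_proj_with_mqa = bool(s2_kv < shared_weight_scale_2)

# alpha combines the activation and weight global scales, as in the diff.
alpha = shared_input_scale.float() * shared_weight_scale_2.float()
print(need_requant_q_a_proj, need_requant_kv_a_proj_with_mqa, alpha.item())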

0 commit comments