37 | 37 | from tqdm import tqdm |
38 | 38 | from transformers import PretrainedConfig |
39 | 39 |
| 40 | +import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils |
40 | 41 | from tensorrt_llm._ipc_utils import can_access_peer |
41 | 42 | from tensorrt_llm._utils import get_sm_version |
42 | 43 | from tensorrt_llm.functional import PositionEmbeddingType |
@@ -142,6 +143,44 @@ def __init__(self, model, is_draft_model: bool = False): |
142 | 143 |
143 | 144 | def load_weights(self, weights: Dict, skip_modules: List[str] = []): |
144 | 145 |
| 146 | + def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2, |
| 147 | + new_scale_2, device): |
| 148 | + """ |
| 149 | + Dequantize FP4 weights and requantize with a new scale. |
| 150 | +
| 151 | + Args: |
| 153 | + weight: packed FP4 (e2m1x2) weight tensor of shape [out_features, in_features // 2] |
| 153 | + weight_scale: FP8 per-block scaling factors |
| 154 | + old_scale_2: original global scale (amax/(448*6)) |
| 155 | + new_scale_2: new global scale (amax/(448*6)) |
| 156 | + device: target device for computation |
| 157 | +
| 158 | + Returns: |
| 159 | + (requantized_weight, new_weight_scale) |
| 160 | + """ |
| 161 | + # Remember original dtype of weight_scale |
| 162 | + original_scale_dtype = weight_scale.dtype |
| 163 | + original_scale_shape = weight_scale.shape |
| 164 | + |
| 165 | + # Dequantize |
| 166 | + dequant_shape = (weight.shape[0], weight.shape[1] * 2) |
| 167 | + weight_dequant = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2( |
| 168 | + weight.contiguous(), |
| 169 | + weight_scale.flatten().view( |
| 170 | + fp4_utils.float4_sf_dtype).contiguous(), old_scale_2, 16, 1, |
| 171 | + True).to(dtype=torch.bfloat16).reshape(dequant_shape) |
| 172 | + |
| 173 | + # Requantize using the new_scale_2 |
| 174 | + weight_requant, weight_scale_requant = torch.ops.trtllm.fp4_quantize( |
| 175 | + weight_dequant.to(device), |
| 176 | + 1.0 / new_scale_2.to(device), |
| 177 | + 16, # scaling_vector_size |
| 178 | + False) |
| 179 | + |
| 180 | + # Ensure the returned scale has the same dtype as the input scale |
| 181 | + return weight_requant.cpu(), weight_scale_requant.reshape( |
| 182 | + original_scale_shape).view(original_scale_dtype).cpu() |
| 183 | + |
145 | 184 | def rename_moe_weight(weights: Dict, rename_rules: Dict): |
146 | 185 | result = {} |
147 | 186 | for key, value in weights.items(): |
@@ -355,27 +394,128 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, |
355 | 394 | ).view(*attn_module.v_b_proj_dequant.shape).to( |
356 | 395 | attn_module.v_b_proj_dequant.dtype)) |
357 | 396 | elif names[-1] == "kv_a_proj_with_mqa": |
358 | | - fused_a = weights[ |
359 | | - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] |
360 | | - if not is_lite: |
361 | | - q_a_proj = weights[ |
362 | | - f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] |
363 | | - fused_a = torch.cat([q_a_proj, fused_a], dim=0) |
364 | | - |
365 | | - if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: |
366 | | - fused_a_scale = weights[ |
367 | | - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] |
| 397 | + nvfp4_fused_a = self.model_config.get_quant_config( |
| 398 | + ).layer_quant_mode.has_nvfp4() and weights[ |
| 399 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"].dtype == fp4_utils.float4_e2m1x2 and (is_lite or weights[ |
| 400 | + f"{'.'.join(names[:-1])}.q_a_proj.weight"].dtype == fp4_utils.float4_e2m1x2) |
| 401 | + if nvfp4_fused_a: |
| 402 | + ########### input_scale |
| 403 | + kv_a_proj_with_mqa_input_scale = weights[ |
| 404 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.input_scale"] |
| 405 | + if not is_lite: |
| 406 | + q_a_proj_input_scale = weights[ |
| 407 | + f"{'.'.join(names[:-1])}.q_a_proj.input_scale"] |
| 408 | + assert kv_a_proj_with_mqa_input_scale == q_a_proj_input_scale, "kv_a_proj_with_mqa.input_scale and q_a_proj.input_scale should be the same" |
| 409 | + # modelopt ckpt stores amax/(448*6), convert to (448*6)/amax |
| 410 | + shared_input_scale = kv_a_proj_with_mqa_input_scale |
| 411 | + module.input_scale.data.copy_(1.0 / shared_input_scale) |
| 412 | + E2M1_MAX = 6.0 |
| 413 | + module.inv_input_scale.data.copy_(module.input_scale / |
| 414 | + E2M1_MAX) |
| 415 | + ########### weight_scale_2 |
| 416 | + need_requant_kv_a_proj_with_mqa = False |
| 417 | + need_requant_q_a_proj = False |
| 418 | + kv_a_proj_with_mqa_scale_2 = weights[ |
| 419 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_2"] |
| 420 | + shared_weight_scale_2 = kv_a_proj_with_mqa_scale_2 |
| 421 | + if not is_lite: |
| 422 | + q_a_proj_scale_2 = weights[ |
| 423 | + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_2"] |
| 424 | + if kv_a_proj_with_mqa_scale_2 < q_a_proj_scale_2: |
| 425 | + shared_weight_scale_2 = q_a_proj_scale_2 |
| 426 | + need_requant_kv_a_proj_with_mqa = True |
| 427 | + elif q_a_proj_scale_2 < kv_a_proj_with_mqa_scale_2: |
| 428 | + need_requant_q_a_proj = True |
| 429 | + |
| 430 | + ########### alpha |
| 431 | + alpha = shared_input_scale.float( |
| 432 | + ) * shared_weight_scale_2.float() |
| 433 | + module.alpha.data.copy_(alpha) |
| 434 | + module.scalar_alpha = alpha.item() |
| 435 | + |
| 436 | + ########### weights |
| 437 | + kv_a_proj_with_mqa = weights[ |
| 438 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] |
| 439 | + |
| 440 | + if not is_lite: |
| 441 | + q_a_proj = weights[ |
| 442 | + f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] |
| 443 | + |
| 444 | + ########### weight_scale |
| 445 | + kv_a_proj_with_mqa_scale = weights[ |
| 446 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale"][:] |
| 447 | + kv_a_proj_with_mqa_scale = torch.ops.trtllm.block_scale_interleave( |
| 448 | + kv_a_proj_with_mqa_scale.view( |
| 449 | + fp4_utils.float4_sf_dtype)) |
368 | 450 | if not is_lite: |
369 | 451 | q_a_proj_scale = weights[ |
370 | | - f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] |
371 | | - fused_a_scale = torch.cat( |
372 | | - [q_a_proj_scale, fused_a_scale], dim=0) |
| 452 | + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale"][:] |
| 453 | + q_a_proj_scale = torch.ops.trtllm.block_scale_interleave( |
| 454 | + q_a_proj_scale.view(fp4_utils.float4_sf_dtype)) |
| 455 | + |
| 456 | + ########### requantize |
| 457 | + if need_requant_kv_a_proj_with_mqa: |
| 458 | + # requant kv_a_proj_with_mqa |
| 459 | + kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale( |
| 460 | + kv_a_proj_with_mqa, |
| 461 | + kv_a_proj_with_mqa_scale, |
| 462 | + kv_a_proj_with_mqa_scale_2, |
| 463 | + shared_weight_scale_2, |
| 464 | + device=module.weight.device, |
| 465 | + ) |
| 466 | + if need_requant_q_a_proj: |
| 467 | + # requant q_a_proj |
| 468 | + q_a_proj, q_a_proj_scale = requantize_weight_with_new_scale( |
| 469 | + q_a_proj, |
| 470 | + q_a_proj_scale, |
| 471 | + q_a_proj_scale_2, |
| 472 | + shared_weight_scale_2, |
| 473 | + device=module.weight.device) |
| 474 | + |
| 475 | + ########### fuse and load weights |
| 476 | + if not is_lite: |
| 477 | + fused_a = torch.cat([q_a_proj, kv_a_proj_with_mqa], |
| 478 | + dim=0) |
| 479 | + else: |
| 480 | + fused_a = kv_a_proj_with_mqa |
| 481 | + |
| 482 | + # For DeepseekV32: kv_a_proj_with_mqa is oversized |
| 483 | + # to include indexer k weights, which is filled in post_load_weights. |
| 484 | + module.weight.data[0:fused_a.shape[0]].copy_(fused_a) |
373 | 485 |
| 486 | + ########### fuse weight_scale |
| 487 | + if not is_lite: |
| 488 | + fused_a_scale = torch.cat( |
| 489 | + [q_a_proj_scale, kv_a_proj_with_mqa_scale], |
| 490 | + dim=0) |
| 491 | + else: |
| 492 | + fused_a_scale = kv_a_proj_with_mqa_scale |
| 493 | + # For DeepseekV32: kv_a_proj_with_mqa is oversized |
| 494 | + # to include indexer k weights, which is filled in post_load_weights. |
374 | 495 | module.weight_scale.data[0:fused_a_scale. |
375 | 496 | shape[0]].copy_(fused_a_scale) |
376 | | - # For DeepseekV32: kv_a_proj_with_mqa is oversized |
377 | | - # to include indexer k weights, which is filled in post_load_weights. |
378 | | - module.weight.data[0:fused_a.shape[0]].copy_(fused_a) |
| 497 | + else: |
| 498 | + fused_a = weights[ |
| 499 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] |
| 500 | + if not is_lite: |
| 501 | + q_a_proj = weights[ |
| 502 | + f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] |
| 503 | + fused_a = torch.cat([q_a_proj, fused_a], dim=0) |
| 504 | + |
| 505 | + if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: |
| 506 | + fused_a_scale = weights[ |
| 507 | + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] |
| 508 | + if not is_lite: |
| 509 | + q_a_proj_scale = weights[ |
| 510 | + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] |
| 511 | + fused_a_scale = torch.cat( |
| 512 | + [q_a_proj_scale, fused_a_scale], dim=0) |
| 513 | + |
| 514 | + module.weight_scale.data[ |
| 515 | + 0:fused_a_scale.shape[0]].copy_(fused_a_scale) |
| 516 | + # For DeepseekV32: kv_a_proj_with_mqa is oversized |
| 517 | + # to include indexer k weights, which is filled in post_load_weights. |
| 518 | + module.weight.data[0:fused_a.shape[0]].copy_(fused_a) |
379 | 519 | elif names[-1] in params_map: |
380 | 520 | module_weights = [] |
381 | 521 | for new_name in params_map[names[-1]]: |
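
For reference, the scale bookkeeping in the NVFP4 branch above (`shared_input_scale`, `shared_weight_scale_2`, `alpha`) can be reproduced without the trtllm kernels. The sketch below only illustrates the convention the code relies on: ModelOpt checkpoints store global scales as `amax / (448 * 6)`, the fused GEMM keeps a single shared `weight_scale_2` (the larger of the two), and `alpha` folds both global scales back into the output. Tensor names and shapes are made up for the example.

```python
import torch

# NVFP4 constants: FP8 (e4m3) block-scale max is 448, FP4 (e2m1) max is 6.
FP8_MAX = 448.0
E2M1_MAX = 6.0

def global_scale(t: torch.Tensor) -> torch.Tensor:
    # ModelOpt-style global scale: amax / (448 * 6).
    return t.abs().amax().float() / (FP8_MAX * E2M1_MAX)

# Hypothetical tensors; shapes chosen only for illustration.
q_a_proj_w = torch.randn(1536, 7168)
kv_a_proj_w = 0.1 * torch.randn(576, 7168)   # narrower dynamic range

q_a_proj_scale_2 = global_scale(q_a_proj_w)
kv_a_proj_scale_2 = global_scale(kv_a_proj_w)

# The fused q_a_proj / kv_a_proj_with_mqa GEMM carries one weight_scale_2, so
# the larger of the two is kept and the other tensor is requantized against it
# (mirroring the need_requant_* flags above); the larger scale covers both
# amax ranges without clipping.
shared_weight_scale_2 = torch.maximum(q_a_proj_scale_2, kv_a_proj_scale_2)

# The checkpoint stores input_scale as amax / (448 * 6); the module keeps the
# reciprocal, plus a copy pre-divided by the FP4 max for activation quantization.
input_scale_ckpt = global_scale(torch.randn(4, 7168))  # stand-in activations
input_scale = 1.0 / input_scale_ckpt
inv_input_scale = input_scale / E2M1_MAX

# alpha multiplies the GEMM output to undo both global scales.
alpha = input_scale_ckpt.float() * shared_weight_scale_2.float()
```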
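Similarly, the effect of `requantize_weight_with_new_scale` can be checked with an unpacked, pure-PyTorch stand-in for the e2m1 kernels. This is a behavioural sketch only: block scales are kept in float rather than stored as FP8, values are not packed two per byte, and no swizzled layout is used. It shows that re-encoding a tensor against a larger global scale reproduces the same dequantized values up to rounding.

```python
import torch

E2M1_MAX = 6.0
BLOCK = 16  # NVFP4 scaling vector size

# Representable e2m1 magnitudes (sign handled separately).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_ref(w: torch.Tensor, scale_2: torch.Tensor):
    """Unpacked NVFP4-style reference: per-16-element block scales + one global scale."""
    blocks = w.float().reshape(-1, BLOCK)
    # Per-block scale so the largest element maps to E2M1_MAX; the global scale
    # is divided out because the stored block scale is relative to it.
    block_scale = blocks.abs().amax(dim=1, keepdim=True) / E2M1_MAX / scale_2
    block_scale = block_scale.clamp(min=torch.finfo(torch.float32).tiny)
    scaled = blocks / (block_scale * scale_2)
    # Round each element to the nearest representable e2m1 magnitude.
    idx = (scaled.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)
    return E2M1_VALUES[idx] * scaled.sign(), block_scale

def dequantize_ref(q, block_scale, scale_2):
    return q * block_scale * scale_2

w = torch.randn(8, 64)
old_scale_2 = w.abs().amax() / (448.0 * E2M1_MAX)
new_scale_2 = 4.0 * old_scale_2  # e.g. the larger scale_2 of the fused peer

q_old, s_old = quantize_ref(w, old_scale_2)
w_dequant = dequantize_ref(q_old, s_old, old_scale_2)

# Requantize the dequantized weights against the new global scale, as the
# helper above does with the trtllm kernels, then compare.
q_new, s_new = quantize_ref(w_dequant.reshape_as(w), new_scale_2)
w_requant = dequantize_ref(q_new, s_new, new_scale_2)
print(torch.allclose(w_dequant, w_requant))  # True up to float rounding
```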