
Commit 228ad93

yuxianq authored and dominicshanshan committed
[None][fix] Fix dummy load format for key models. (NVIDIA#7993)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 78f2cbb commit 228ad93
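
Taken together, the changes below do two things. First, the SM100 (Blackwell) FP8 block-scale requantization that several models duplicated in their own post_load_weights overrides moves into the shared linear and MoE quantization methods. Second, modules whose weights were removed (as the dummy load format does) are now flagged and skipped when post-load hooks run. A minimal sketch of that second mechanism, in toy code rather than the TensorRT-LLM sources:

from torch import nn


def remove_weights(module: nn.Module) -> None:
    # As in modeling_utils.remove_weights below: drop parameters and
    # buffers, then mark the module so post-load hooks know nothing
    # is left to transform.
    for mod in module.modules():
        mod._parameters.clear()
        mod._buffers.clear()
        mod._weights_removed = True


def run_post_load_hooks(model: nn.Module) -> None:
    # Mirrors the guarded loop added in pyexecutor/model_loader.py.
    for mod in model.modules():
        if hasattr(mod, "post_load_weights") and not getattr(
                mod, "_weights_removed", False):
            mod.post_load_weights()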


15 files changed: +99 -66 lines


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 23 deletions

@@ -40,14 +40,12 @@
 from transformers import PretrainedConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -1528,26 +1526,6 @@ def load_weights(self, weights: Dict):
         weight_loader.load_weights(weights)

     def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(),
-                                 desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and is_sm_100f() and hasattr(module, "weight_scale"):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(
-                        module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False)
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale,
-                                                       requires_grad=False)
-
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 0 deletions

@@ -630,6 +630,7 @@ def load_weights(self, weights: Dict):
         else:
             self.load_hf_weights(weights)

+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.block[:self.config.num_hidden_layers]):
             if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 3 deletions

@@ -980,9 +980,7 @@ def __init__(
     ):
         super().__init__(LlamaModel(model_config), model_config)

-    def load_weights(self, weights: Dict):
-        super().load_weights(weights)
-
+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
@@ -1321,6 +1319,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if had_mm_encoder:
             self.mm_encoder = saved_mm_encoder

+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
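
The same refactor appears in modeling_gpt_oss.py above and in modeling_qwen3_moe.py below: per-layer wiring that used to run at the tail of a load_weights override moves into a post_load_weights hook, which the loader invokes for every load format, including dummy loads that never call load_weights. A toy sketch of the pattern (invented model, not the real classes):

from torch import nn


class TinyDecoder(nn.Module):

    def __init__(self, num_hidden_layers: int = 4):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.layers = nn.ModuleList(
            nn.Linear(8, 8) for _ in range(num_hidden_layers))

    def load_weights(self, weights: dict) -> None:
        # Only real checkpoints pass through here.
        self.load_state_dict(weights, strict=False)

    def post_load_weights(self) -> None:
        # Runs regardless of load format, so the last-layer special
        # case still happens when load_weights was never called.
        for idx, layer in enumerate(self.layers[:self.num_hidden_layers]):
            if idx == self.num_hidden_layers - 1:
                layer.is_last = True  # stand-in for the real wiring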

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 0 additions & 25 deletions

@@ -2,13 +2,9 @@

 import torch
 from torch import nn
-from tqdm import tqdm
 from transformers import Qwen3Config

-from tensorrt_llm._utils import is_sm_100f
 from tensorrt_llm.functional import PositionEmbeddingType
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -225,24 +221,3 @@ def __init__(
             Qwen3Model(model_config),
             model_config,
         )
-
-    def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(),
-                                 desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-                ) and is_sm_100f() and hasattr(module, "weight_scale"):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(
-                        module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False)
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale,
-                                                       requires_grad=False)

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 5 deletions

@@ -6,8 +6,6 @@
 from transformers import Qwen3MoeConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
-    BaseWeightMapper

 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
@@ -390,9 +388,7 @@ def __init__(
         )
         self.preload_weight_modules = self.model.preload_weight_modules

-    def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
-        super().load_weights(weights, weight_mapper)
-
+    def post_load_weights(self):
         for idx, layer in enumerate(
                 self.model.layers[:self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions

@@ -143,6 +143,7 @@ def remove_weights(
     for mod in iter_modules(module, ignore_modules):
         mod._parameters.clear()
         mod._buffers.clear()
+        mod._weights_removed = True


 def skip_forward(

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 5 additions & 5 deletions

@@ -6,8 +6,8 @@
 import torch.nn.functional as F
 from torch import nn

-import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.utils.fp4_utils import (
@@ -271,8 +271,6 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                          module.w2_bias.data if module.bias else None)

         self.load_quant_scales(module, weights)
-        # Re-setup quant scales after loading weights as the tensors may have been modified.
-        self.setup_quant_scales(module)

         if self.need_load_shared_weights(module):
             local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
@@ -323,7 +321,8 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                 module.initial_global_assignments)

     def post_load_weights(self, module: torch.nn.Module):
-        pass
+        # Re-setup quant scales after loading weights as the tensors may have been modified.
+        self.setup_quant_scales(module)

     def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
         pass
@@ -722,14 +721,15 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
             if int(name.split(".")[0]) not in expert_ids:
                 continue
             weight_name = name.replace("weight_scale_inv", "weight")
-            trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
+            logger.debug(f"Resmoothing {weight_name}")
             weight = weights[weight_name][:]
             scale = weights[name][:]
             weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
                 weight, scale)
         super().load_weights(module, weights, weight_loading_mode)

     def post_load_weights(self, module: torch.nn.Module):
+        super().post_load_weights(module)
         if is_sm_100f():
             transfromed_w3_w1_scale = transform_sf_into_required_layout(
                 module.quant_scales[0],
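
The subtle part here is ordering: setup_quant_scales moves from the end of load_weights into post_load_weights, and the FP8 block-scales subclass now chains super().post_load_weights(module) before its SM100 transform, so the scale views exist and are current when they are repacked. A toy illustration with hypothetical attribute names, not the real classes:

import torch


class BaseMoEMethod:

    def post_load_weights(self, module) -> None:
        # Re-derive the cached scale tuple from the live tensors,
        # which load-time steps may have replaced.
        module.quant_scales = (module.w3_w1_weight_scale, )


class FP8BlockScalesMoEMethod(BaseMoEMethod):

    def post_load_weights(self, module) -> None:
        super().post_load_weights(module)  # scales must be current first
        # Stand-in for transform_sf_into_required_layout on SM100.
        module.quant_scales = (module.quant_scales[0].t().contiguous(), )


class FakeMoEModule:
    pass


m = FakeMoEModule()
m.w3_w1_weight_scale = torch.ones(2, 4)
FP8BlockScalesMoEMethod().post_load_weights(m)
assert m.quant_scales[0].shape == (4, 2)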

tensorrt_llm/_torch/modules/linear.py

Lines changed: 20 additions & 0 deletions

@@ -20,6 +20,8 @@
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
 from tensorrt_llm.quantization.mode import QuantAlgo
+from tensorrt_llm.quantization.utils.fp8_utils import (
+    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

 from ..._utils import is_sm_100f
 from ...models.modeling_utils import QuantConfig
@@ -715,6 +717,24 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)

+    def post_load_weights(self, module: Linear):
+        super().post_load_weights(module)
+        if is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm
+                                 or module.disable_deep_gemm):
+            weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
+                                                        module.weight_scale)
+            transfromed_scale = transform_sf_into_required_layout(
+                weight_scale,
+                mn=weight.shape[0],
+                k=weight.shape[1],
+                recipe=(1, 128, 128),
+                is_sfa=False)
+            module.weight = nn.Parameter(weight, requires_grad=False)
+            module.weight_scale = nn.Parameter(
+                transfromed_scale,
+                requires_grad=False,
+            )
+

 class NVFP4LinearMethod(LinearMethodBase):
tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 2 additions & 1 deletion

@@ -267,7 +267,8 @@ def init_meta_tensor(t: torch.Tensor):
                 f"No load support for load format: {load_format}")

     for module in model.modules():
-        if hasattr(module, 'post_load_weights'):
+        if hasattr(module, 'post_load_weights') and not getattr(
+                module, '_weights_removed', False):
             module.post_load_weights()

     if isinstance(moe_load_balancer, MoeLoadBalancer):

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 4 additions & 2 deletions

@@ -186,7 +186,8 @@ def evaluate(self,
                 extra_acc_spec: Optional[str] = None,
                 extra_evaluator_kwargs: Optional[dict] = None,
                 sampling_params: Optional[SamplingParams] = None,
-                streaming: bool = False):
+                streaming: bool = False,
+                is_integration_test: bool = False):
        assert self.EVALUATOR_CLS is not None

        if llm.args.speculative_config is None:
@@ -199,7 +200,8 @@ def evaluate(self,
            raise ValueError(
                f"Not recognized speculative_config: {llm.args.speculative_config}."
            )
-        is_integration_test = os.getenv('INTEGRATION_TEST', '0') == '1'
+        is_integration_test = is_integration_test or os.getenv(
+            'INTEGRATION_TEST', '0') == '1'

        if is_integration_test:
            logger.info(
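
The test-harness tweak is independent of the loader changes: evaluate now accepts is_integration_test as a keyword argument, combined with the existing INTEGRATION_TEST environment variable. A minimal sketch of the resolution logic:

import os


def resolve_integration_test(explicit: bool = False) -> bool:
    # Either the new keyword argument or the environment variable
    # enables integration-test mode; the env var keeps working as before.
    return explicit or os.getenv('INTEGRATION_TEST', '0') == '1'


assert resolve_integration_test(True)
assert not resolve_integration_test()  # assuming INTEGRATION_TEST is unset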
