Skip to content

Commit 19a3031

Browse files
[TRTLLM-10329][feat] Fix weight loading for Nemotron 3 models on DGX Spark (#11405)
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
1 parent 052fe2f commit 19a3031

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,9 @@ def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
8080
elif "A" in key:
8181
w = split(weights[name], tp_size, tp_rank)
8282
w = w.to(torch.float32)
83-
w = -torch.exp(w)
83+
# Avoid extra temporaries: one fp32 cast, then in-place exp/neg.
84+
w.exp_()
85+
w.neg_()
8486
new_weights[key] = w
8587
elif "D" in key:
8688
w = split(weights[name], tp_size, tp_rank)

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import torch.nn.functional as F
2424
from torch import nn
2525

26-
from tensorrt_llm._utils import get_sm_version, is_sm_100f
26+
from tensorrt_llm._utils import get_sm_version, is_device_integrated, is_sm_100f
2727
from tensorrt_llm.logger import logger
2828
from tensorrt_llm.quantization.functional import \
2929
preprocess_weights_for_mixed_gemm
@@ -38,6 +38,7 @@
3838
unswizzle_sf)
3939
from ..linear import TensorParallelMode, load_weight_shard
4040
from .interface import MoEWeightLoadingMode
41+
from .moe_load_balancer import advise_tensor_pageout
4142

4243
# The declarations aligns with moe_kernels.h
4344
# pack inputs into int64, e.g. 4 x bf16 input values
@@ -306,6 +307,20 @@ def load_expert_weights_to_dst(
306307
w3_w1_kargs["allow_partial_loading"] = allow_partial_loading
307308
if "allow_partial_loading" in w2_args:
308309
w2_kargs["allow_partial_loading"] = allow_partial_loading
310+
311+
def maybe_pageout_mmapped_cpu_weights(
312+
weight_tensors: List[object]) -> None:
313+
# Integrated GPU systems share physical memory with CPU. After we
314+
# finish copying from mmapped CPU weights, proactively advising the
315+
# kernel to drop those pages reduces shared-memory pressure.
316+
if not is_device_integrated():
317+
return
318+
for weight in weight_tensors:
319+
if (isinstance(weight, torch.Tensor)
320+
and weight.device.type == "cpu"
321+
and weight.is_contiguous()):
322+
advise_tensor_pageout(weight)
323+
309324
# Multithread weight load is superseded by prefetch_files() in model_engine.py
310325
# Also, threading adds overhead in order to protect shuffle index cache with critical section.
311326
for local_slot_id, expert_id in enumerate(load_expert_ids):
@@ -361,6 +376,7 @@ def load_expert_weights_to_dst(
361376
if weight is not None
362377
]
363378
module._add_raw_shared_weights_for_unmap(unmap_weights)
379+
maybe_pageout_mmapped_cpu_weights(unmap_weights)
364380

365381
if module.bias:
366382
self.load_expert_w3_w1_weight(
@@ -375,6 +391,7 @@ def load_expert_weights_to_dst(
375391
if weight is not None
376392
]
377393
module._add_raw_shared_weights_for_unmap(unmap_weights)
394+
maybe_pageout_mmapped_cpu_weights(unmap_weights)
378395

379396
def load_weights(self,
380397
module: torch.nn.Module,

0 commit comments

Comments (0)