2323import torch .nn .functional as F
2424from torch import nn
2525
26- from tensorrt_llm ._utils import get_sm_version , is_sm_100f
26+ from tensorrt_llm ._utils import get_sm_version , is_device_integrated , is_sm_100f
2727from tensorrt_llm .logger import logger
2828from tensorrt_llm .quantization .functional import \
2929 preprocess_weights_for_mixed_gemm
3838 unswizzle_sf )
3939from ..linear import TensorParallelMode , load_weight_shard
4040from .interface import MoEWeightLoadingMode
41+ from .moe_load_balancer import advise_tensor_pageout
4142
4243# The declarations aligns with moe_kernels.h
4344# pack inputs into int64, e.g. 4 x bf16 input values
@@ -306,6 +307,20 @@ def load_expert_weights_to_dst(
306307 w3_w1_kargs ["allow_partial_loading" ] = allow_partial_loading
307308 if "allow_partial_loading" in w2_args :
308309 w2_kargs ["allow_partial_loading" ] = allow_partial_loading
310+
def maybe_pageout_mmapped_cpu_weights(weight_tensors: List[object]) -> None:
    """Advise the kernel to drop mmapped CPU weight pages on integrated GPUs.

    Integrated-GPU systems share one physical memory pool between CPU and
    GPU, so once the copy out of an mmapped CPU weight has finished,
    proactively paging those source pages out lowers shared-memory
    pressure. On discrete-GPU systems this is a no-op.
    """
    # Only integrated (unified-memory) devices benefit; bail out otherwise.
    if not is_device_integrated():
        return
    for candidate in weight_tensors:
        # Skip anything that is not a tensor at all.
        if not isinstance(candidate, torch.Tensor):
            continue
        # Only contiguous CPU tensors map cleanly onto whole mmapped
        # page ranges that the kernel can be advised to release.
        if candidate.device.type != "cpu":
            continue
        if not candidate.is_contiguous():
            continue
        advise_tensor_pageout(candidate)
323+
309324 # Multithread weight load is superseded by prefetch_files() in model_engine.py
310325 # Also, threading adds overhead in order to protect shuffle index cache with critical section.
311326 for local_slot_id , expert_id in enumerate (load_expert_ids ):
@@ -361,6 +376,7 @@ def load_expert_weights_to_dst(
361376 if weight is not None
362377 ]
363378 module ._add_raw_shared_weights_for_unmap (unmap_weights )
379+ maybe_pageout_mmapped_cpu_weights (unmap_weights )
364380
365381 if module .bias :
366382 self .load_expert_w3_w1_weight (
@@ -375,6 +391,7 @@ def load_expert_weights_to_dst(
375391 if weight is not None
376392 ]
377393 module ._add_raw_shared_weights_for_unmap (unmap_weights )
394+ maybe_pageout_mmapped_cpu_weights (unmap_weights )
378395
379396 def load_weights (self ,
380397 module : torch .nn .Module ,
0 commit comments