1- import threading
21from abc import ABC , abstractmethod
3- from typing import Dict , List , NamedTuple
2+ from typing import Dict , List , NamedTuple , Union
43
54import torch
65from torch import nn
76
87from tensorrt_llm ._utils import get_sm_version
98from tensorrt_llm .quantization .utils .fp4_utils import (
109 float4_sf_dtype , get_reorder_rows_for_gated_act_gemm_row_indices ,
11- get_shuffle_matrix_a_row_indices , get_shuffle_matrix_sf_a_row_indices ,
12- shuffle_matrix_a , shuffle_matrix_sf_a )
10+ get_shuffle_matrix_a_row_indices , get_shuffle_matrix_sf_a_row_indices )
1311
1412from ..linear import TensorParallelMode , load_weight_shard
1513from .interface import MoEWeightLoadingMode
@@ -80,12 +78,8 @@ def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype,
8078
8179 def load_weights (self , module : torch .nn .Module , weights : List [Dict ],
8280 weight_loading_mode : MoEWeightLoadingMode ):
83- # Use multi-threading to load expert weights in parallel.
84- # Even though CPython has global interpreter lock (GIL),
85- # it's still faster to load weights in parallel because it can utilize
86- # CPU memory bandwidth better.
87- threads = []
88-
81+ # Multithread weight load is superseded by prefetch_files() in model_engine.py
82+ # Also, threading adds overhead in order to protect shuffle index cache with critical section.
8983 for local_slot_id , expert_id in enumerate (
9084 module .initial_local_expert_ids ):
9185 # expert_idx is the local slot index of current rank
@@ -106,21 +100,11 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
106100 f"Unknown weight loading mode in MoE: { weight_loading_mode } "
107101 )
108102
109- thread = threading .Thread (
110- target = self .load_expert_w3_w1_weight ,
111- args = (module , w1_weight , w3_weight ,
112- module .w3_w1_weight .data [expert_idx ]))
113- thread .start ()
114- threads .append (thread )
115-
116- thread = threading .Thread (target = self .load_expert_w2_weight ,
117- args = (module , w2_weight ,
118- module .w2_weight .data [expert_idx ]))
119- thread .start ()
120- threads .append (thread )
103+ self .load_expert_w3_w1_weight (module , w1_weight , w3_weight ,
104+ module .w3_w1_weight .data [expert_idx ])
121105
122- for thread in threads :
123- thread . join ( )
106+ self . load_expert_w2_weight ( module , w2_weight ,
107+ module . w2_weight . data [ expert_idx ] )
124108
125109 self .load_quant_scales (module , weights )
126110 # Re-setup quant scales after loading weights as the tensors may have been modified.
@@ -1011,6 +995,53 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
1011995 weight_dtype = float4_sf_dtype
1012996 block_scales_dtype = torch .float8_e4m3fn
1013997
998+ # Cache the permute indices during weight loading to avoid recompute
999+ # This assumes the same input shape always results in the same permute indices
1000+ _cache_permute_indices : Dict [torch .Size , torch .Tensor ] = {}
1001+
1002+ def _maybe_get_cached_w3_w1_permute_indices (
1003+ self ,
1004+ dst_w3_w1_weight : torch .Tensor ,
1005+ epilogue_tile_m : int ,
1006+ num_elts_per_sf : Union [None , int ] = None ) -> torch .Tensor :
1007+ if dst_w3_w1_weight .shape not in self ._cache_permute_indices :
1008+ # Get permute indices and chain them together
1009+ permute0 = get_reorder_rows_for_gated_act_gemm_row_indices (
1010+ dst_w3_w1_weight )
1011+ if num_elts_per_sf is None :
1012+ permute1 = get_shuffle_matrix_a_row_indices (
1013+ dst_w3_w1_weight , epilogue_tile_m = epilogue_tile_m )
1014+ else :
1015+ permute1 = get_shuffle_matrix_sf_a_row_indices (
1016+ dst_w3_w1_weight ,
1017+ epilogue_tile_m = epilogue_tile_m ,
1018+ num_elts_per_sf = num_elts_per_sf )
1019+ # Memoize permute indices as recompute is **very** costly
1020+ self ._cache_permute_indices [
1021+ dst_w3_w1_weight .shape ] = permute0 [permute1 ].to (
1022+ dst_w3_w1_weight .device )
1023+ permute_indices = self ._cache_permute_indices [dst_w3_w1_weight .shape ]
1024+ return permute_indices
1025+
1026+ def _maybe_get_cached_w2_permute_indices (
1027+ self ,
1028+ dst_w2_weight : torch .Tensor ,
1029+ epilogue_tile_m : int ,
1030+ num_elts_per_sf : Union [None , int ] = None ) -> torch .Tensor :
1031+ if dst_w2_weight .shape not in self ._cache_permute_indices :
1032+ if num_elts_per_sf is None :
1033+ permute_indices = (get_shuffle_matrix_a_row_indices (
1034+ dst_w2_weight , epilogue_tile_m ).to (dst_w2_weight .device ))
1035+ else :
1036+ permute_indices = (get_shuffle_matrix_sf_a_row_indices (
1037+ dst_w2_weight ,
1038+ epilogue_tile_m = epilogue_tile_m ,
1039+ num_elts_per_sf = num_elts_per_sf ).to (dst_w2_weight .device ))
1040+ # Memoize permute indices as recompute is **very** costly
1041+ self ._cache_permute_indices [dst_w2_weight .shape ] = permute_indices
1042+ permute_indices = self ._cache_permute_indices [dst_w2_weight .shape ]
1043+ return permute_indices
1044+
10141045 def create_weights (self , module : torch .nn .Module ):
10151046 weight_vec_size = torch .iinfo (self .weight_dtype ).bits // 4
10161047 block_scales_vec_size = 1
@@ -1056,16 +1087,13 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
10561087 dst_w3_weight .copy_ (w3_weight_shard .view (dst_w3_weight .dtype ))
10571088 dst_w1_weight .copy_ (w1_weight_shard .view (dst_w1_weight .dtype ))
10581089
1059- # Get permute indices and chain them together
1060- permute0 = get_reorder_rows_for_gated_act_gemm_row_indices (
1061- dst_w3_w1_weight )
1062- permute1 = get_shuffle_matrix_a_row_indices (dst_w3_w1_weight ,
1063- epilogue_tile_m )
1064- permute = permute0 [permute1 ]
1090+ # Get permute indices
1091+ permute_indices = self ._maybe_get_cached_w3_w1_permute_indices (
1092+ dst_w3_w1_weight , epilogue_tile_m )
10651093
10661094 # Shuffle the weight according to permute indices
10671095 processed_w31_weight_shard = torch .ops .trtllm .shuffle_matrix (
1068- dst_w3_w1_weight , permute .to (dst_w3_w1_weight .device ))
1096+ dst_w3_w1_weight , permute_indices .to (dst_w3_w1_weight .device ))
10691097
10701098 # Copy the result into device buffer
10711099 dst_w3_w1_weight .copy_ (processed_w31_weight_shard .view (
@@ -1085,8 +1113,14 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
10851113 # Keep weights in device buffer
10861114 dst_w2_weight .copy_ (w2_weight_shard .view (dst_w2_weight .dtype ),
10871115 non_blocking = True )
1088- # Get permuted result
1089- processed_w2_weight = shuffle_matrix_a (dst_w2_weight , epilogue_tile_m )
1116+ # Get permuted indices
1117+ permute_indices = self ._maybe_get_cached_w2_permute_indices (
1118+ dst_w2_weight , epilogue_tile_m )
1119+
1120+ # Shuffle the weight according to permute indices
1121+ processed_w2_weight = torch .ops .trtllm .shuffle_matrix (
1122+ dst_w2_weight , permute_indices .to (dst_w2_weight .device ))
1123+
10901124 # Copy the result into device buffer
10911125 dst_w2_weight .copy_ (processed_w2_weight .view (dst_w2_weight .dtype ),
10921126 non_blocking = True )
@@ -1121,16 +1155,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
11211155 # trtllm-gen specific block scales preprocessing logics
11221156 epilogue_tile_m = 128 # FIXME
11231157
1124- # Get permute indices and chain them together
1125- permute0 = get_reorder_rows_for_gated_act_gemm_row_indices (
1126- dst_w3_w1_weight_scale )
1127- permute1 = get_shuffle_matrix_sf_a_row_indices (
1128- dst_w3_w1_weight_scale .view (float4_sf_dtype ), epilogue_tile_m , 16 )
1129- permute = permute0 [permute1 ]
1158+ # Get permute indices
1159+ permute_indices = self ._maybe_get_cached_w3_w1_permute_indices (
1160+ dst_w3_w1_weight_scale .view (float4_sf_dtype ),
1161+ epilogue_tile_m ,
1162+ num_elts_per_sf = 16 )
11301163
11311164 # Shuffle the weight according to permute indices
11321165 w3_w1_weight_scale = torch .ops .trtllm .shuffle_matrix (
1133- dst_w3_w1_weight_scale .view (float4_sf_dtype ), permute .cuda ())
1166+ dst_w3_w1_weight_scale .view (float4_sf_dtype ), permute_indices )
1167+
11341168 # Assert should only be removed during debugging
11351169 assert w3_w1_weight_scale .is_cuda , "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
11361170 # Interleave the weight.
@@ -1155,13 +1189,26 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
11551189
11561190 # trtllm-gen specific block scales preprocessing logics
11571191 epilogue_tile_m = 128 # FIXME: read from kernel
1192+
11581193 # Assert should only be removed during debugging
11591194 assert dst_w2_weight_scale .is_cuda , "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed"
1160- # Interleave the weight and copy
1195+
1196+ # Get permute indices
1197+ permute_indices = self ._maybe_get_cached_w2_permute_indices (
1198+ dst_w2_weight_scale .view (float4_sf_dtype ),
1199+ epilogue_tile_m ,
1200+ num_elts_per_sf = 16 )
1201+
1202+ # Shuffle the weight according to permute indices
1203+ w_shuffled = torch .ops .trtllm .shuffle_matrix (
1204+ dst_w2_weight_scale .view (dtype = float4_sf_dtype ), permute_indices )
1205+ # Interleave the weight.
1206+ processed_w2_weight_scale = torch .ops .tensorrt_llm .nvfp4_block_scale_interleave (
1207+ w_shuffled )
1208+ # Copy the result into device buffer
11611209 dst_w2_weight_scale .copy_ (
1162- shuffle_matrix_sf_a (
1163- dst_w2_weight_scale .view (float4_sf_dtype ), epilogue_tile_m ,
1164- 16 ).view (self .block_scales_dtype ).reshape (orig_shape ))
1210+ processed_w2_weight_scale .view (
1211+ self .block_scales_dtype ).reshape (orig_shape ))
11651212
11661213 def load_quant_scales (self , module : torch .nn .Module , weights : Dict ):
11671214 super ().load_quant_scales (module , weights )
0 commit comments