Commit eeb555e

chore: memoize weight shuffle index to speed up weight preproc in moe_backend=TRTLLM (NVIDIA#4826)
Signed-off-by: Anthony Chang <[email protected]>
1 parent 1b963c1 commit eeb555e
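
The core of the change is a shape-keyed memoization of the row-permute indices used to shuffle MoE weights for the TRTLLM backend. The sketch below is a minimal, hypothetical illustration of that pattern, not TensorRT-LLM code: compute_permute_indices and get_permute_indices are stand-ins for the real helpers (_maybe_get_cached_w3_w1_permute_indices and _maybe_get_cached_w2_permute_indices in quantization.py, shown in the diff below).

from typing import Dict

import torch

# Simplified sketch of the shape-keyed memoization this commit introduces.
# compute_permute_indices is a stand-in for the costly index-building
# helpers in fp4_utils; the real cache is likewise keyed by the
# destination tensor's shape.
_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}


def compute_permute_indices(weight: torch.Tensor) -> torch.Tensor:
    # Placeholder for an expensive, shape-dependent index computation
    # (deterministic per shape, as the memoization assumes).
    return torch.arange(weight.shape[0] - 1, -1, -1)


def get_permute_indices(weight: torch.Tensor) -> torch.Tensor:
    # All experts share the same weight shape, so the expensive step runs
    # once per distinct shape and becomes a dict lookup afterwards.
    if weight.shape not in _cache_permute_indices:
        _cache_permute_indices[weight.shape] = compute_permute_indices(
            weight).to(weight.device)
    return _cache_permute_indices[weight.shape]


# The first expert pays the cost; the remaining experts hit the cache.
experts = [torch.empty(8, 16) for _ in range(4)]
indices = [get_permute_indices(w) for w in experts]
assert all(idx is indices[0] for idx in indices)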

File tree

7 files changed: +99 -55 lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ namespace Routing
 {
 
 // The type of method in top-K routing, for use in torch custom op
-// Please keep this in sync with the counterpart defined in tensorrt_llm/_torch/modules/fused_moe.py
+// Please keep this in sync with the counterpart defined in tensorrt_llm/_torch/modules/fused_moe/routing.py
 enum class RoutingMethodType : int64_t
 {
     // Default: Softmax -> TopK

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 91 additions & 44 deletions
@@ -1,15 +1,13 @@
-import threading
 from abc import ABC, abstractmethod
-from typing import Dict, List, NamedTuple
+from typing import Dict, List, NamedTuple, Union
 
 import torch
 from torch import nn
 
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
-    get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices,
-    shuffle_matrix_a, shuffle_matrix_sf_a)
+    get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices)
 
 from ..linear import TensorParallelMode, load_weight_shard
 from .interface import MoEWeightLoadingMode
@@ -80,12 +78,8 @@ def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype,
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-        # Use multi-threading to load expert weights in parallel.
-        # Even though CPython has global interpreter lock (GIL),
-        # it's still faster to load weights in parallel because it can utilize
-        # CPU memory bandwidth better.
-        threads = []
-
+        # Multithread weight load is superseded by prefetch_files() in model_engine.py
+        # Also, threading adds overhead in order to protect shuffle index cache with critical section.
         for local_slot_id, expert_id in enumerate(
                 module.initial_local_expert_ids):
             # expert_idx is the local slot index of current rank
@@ -106,21 +100,11 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     f"Unknown weight loading mode in MoE: {weight_loading_mode}"
                 )
 
-            thread = threading.Thread(
-                target=self.load_expert_w3_w1_weight,
-                args=(module, w1_weight, w3_weight,
-                      module.w3_w1_weight.data[expert_idx]))
-            thread.start()
-            threads.append(thread)
-
-            thread = threading.Thread(target=self.load_expert_w2_weight,
-                                      args=(module, w2_weight,
-                                            module.w2_weight.data[expert_idx]))
-            thread.start()
-            threads.append(thread)
+            self.load_expert_w3_w1_weight(module, w1_weight, w3_weight,
+                                          module.w3_w1_weight.data[expert_idx])
 
-        for thread in threads:
-            thread.join()
+            self.load_expert_w2_weight(module, w2_weight,
+                                       module.w2_weight.data[expert_idx])
 
         self.load_quant_scales(module, weights)
         # Re-setup quant scales after loading weights as the tensors may have been modified.
@@ -1011,6 +995,53 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = float4_sf_dtype
     block_scales_dtype = torch.float8_e4m3fn
 
+    # Cache the permute indices during weight loading to avoid recompute
+    # This assumes the same input shape always results in the same permute indices
+    _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+
+    def _maybe_get_cached_w3_w1_permute_indices(
+            self,
+            dst_w3_w1_weight: torch.Tensor,
+            epilogue_tile_m: int,
+            num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
+        if dst_w3_w1_weight.shape not in self._cache_permute_indices:
+            # Get permute indices and chain them together
+            permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
+                dst_w3_w1_weight)
+            if num_elts_per_sf is None:
+                permute1 = get_shuffle_matrix_a_row_indices(
+                    dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m)
+            else:
+                permute1 = get_shuffle_matrix_sf_a_row_indices(
+                    dst_w3_w1_weight,
+                    epilogue_tile_m=epilogue_tile_m,
+                    num_elts_per_sf=num_elts_per_sf)
+            # Memoize permute indices as recompute is **very** costly
+            self._cache_permute_indices[
+                dst_w3_w1_weight.shape] = permute0[permute1].to(
+                    dst_w3_w1_weight.device)
+        permute_indices = self._cache_permute_indices[dst_w3_w1_weight.shape]
+        return permute_indices
+
+    def _maybe_get_cached_w2_permute_indices(
+            self,
+            dst_w2_weight: torch.Tensor,
+            epilogue_tile_m: int,
+            num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
+        if dst_w2_weight.shape not in self._cache_permute_indices:
+            if num_elts_per_sf is None:
+                permute_indices = (get_shuffle_matrix_a_row_indices(
+                    dst_w2_weight, epilogue_tile_m).to(dst_w2_weight.device))
+            else:
+                permute_indices = (get_shuffle_matrix_sf_a_row_indices(
+                    dst_w2_weight,
+                    epilogue_tile_m=epilogue_tile_m,
+                    num_elts_per_sf=num_elts_per_sf).to(dst_w2_weight.device))
+            # Memoize permute indices as recompute is **very** costly
+            self._cache_permute_indices[dst_w2_weight.shape] = permute_indices
+        permute_indices = self._cache_permute_indices[dst_w2_weight.shape]
+        return permute_indices
+
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         block_scales_vec_size = 1
@@ -1056,16 +1087,13 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype))
         dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype))
 
-        # Get permute indices and chain them together
-        permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
-            dst_w3_w1_weight)
-        permute1 = get_shuffle_matrix_a_row_indices(dst_w3_w1_weight,
-                                                    epilogue_tile_m)
-        permute = permute0[permute1]
+        # Get permute indices
+        permute_indices = self._maybe_get_cached_w3_w1_permute_indices(
+            dst_w3_w1_weight, epilogue_tile_m)
 
         # Shuffle the weight according to permute indices
         processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix(
-            dst_w3_w1_weight, permute.to(dst_w3_w1_weight.device))
+            dst_w3_w1_weight, permute_indices.to(dst_w3_w1_weight.device))
 
         # Copy the result into device buffer
         dst_w3_w1_weight.copy_(processed_w31_weight_shard.view(
@@ -1085,8 +1113,14 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         # Keep weights in device buffer
         dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                             non_blocking=True)
-        # Get permuted result
-        processed_w2_weight = shuffle_matrix_a(dst_w2_weight, epilogue_tile_m)
+        # Get permuted indices
+        permute_indices = self._maybe_get_cached_w2_permute_indices(
+            dst_w2_weight, epilogue_tile_m)
+
+        # Shuffle the weight according to permute indices
+        processed_w2_weight = torch.ops.trtllm.shuffle_matrix(
+            dst_w2_weight, permute_indices.to(dst_w2_weight.device))
+
         # Copy the result into device buffer
         dst_w2_weight.copy_(processed_w2_weight.view(dst_w2_weight.dtype),
                             non_blocking=True)
@@ -1121,16 +1155,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
         # trtllm-gen specific block scales preprocessing logics
         epilogue_tile_m = 128  # FIXME
 
-        # Get permute indices and chain them together
-        permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
-            dst_w3_w1_weight_scale)
-        permute1 = get_shuffle_matrix_sf_a_row_indices(
-            dst_w3_w1_weight_scale.view(float4_sf_dtype), epilogue_tile_m, 16)
-        permute = permute0[permute1]
+        # Get permute indices
+        permute_indices = self._maybe_get_cached_w3_w1_permute_indices(
+            dst_w3_w1_weight_scale.view(float4_sf_dtype),
+            epilogue_tile_m,
+            num_elts_per_sf=16)
 
         # Shuffle the weight according to permute indices
         w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
-            dst_w3_w1_weight_scale.view(float4_sf_dtype), permute.cuda())
+            dst_w3_w1_weight_scale.view(float4_sf_dtype), permute_indices)
+
         # Assert should only be removed during debugging
         assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
         # Interleave the weight.
@@ -1155,13 +1189,26 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
         # trtllm-gen specific block scales preprocessing logics
         epilogue_tile_m = 128  # FIXME: read from kernel
+
         # Assert should only be removed during debugging
         assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed"
-        # Interleave the weight and copy
+
+        # Get permute indices
+        permute_indices = self._maybe_get_cached_w2_permute_indices(
+            dst_w2_weight_scale.view(float4_sf_dtype),
+            epilogue_tile_m,
+            num_elts_per_sf=16)
+
+        # Shuffle the weight according to permute indices
+        w_shuffled = torch.ops.trtllm.shuffle_matrix(
+            dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices)
+        # Interleave the weight.
+        processed_w2_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
+            w_shuffled)
+        # Copy the result into device buffer
         dst_w2_weight_scale.copy_(
-            shuffle_matrix_sf_a(
-                dst_w2_weight_scale.view(float4_sf_dtype), epilogue_tile_m,
-                16).view(self.block_scales_dtype).reshape(orig_shape))
+            processed_w2_weight_scale.view(
+                self.block_scales_dtype).reshape(orig_shape))
 
     def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         super().load_quant_scales(module, weights)
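
Dropping the threads ties into the new cache: if expert weights were still loaded from multiple worker threads, the shared _cache_permute_indices dict would need a critical section around its check-and-insert, which is the overhead the new comment in load_weights calls out. A hypothetical sketch of that protection, for comparison only (not part of this commit; serial loading plus prefetch_files() in model_engine.py makes it unnecessary):

import threading
from typing import Dict

import torch

# Hypothetical: what the shared cache would need if expert loads still ran
# on worker threads. The commit avoids this lock by loading serially.
_cache_lock = threading.Lock()
_cache: Dict[torch.Size, torch.Tensor] = {}


def get_indices_thread_safe(weight: torch.Tensor) -> torch.Tensor:
    with _cache_lock:
        if weight.shape not in _cache:
            # Stand-in for the expensive index computation.
            _cache[weight.shape] = torch.arange(weight.shape[0] - 1, -1, -1)
        return _cache[weight.shape]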

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 3 additions & 9 deletions
@@ -1221,15 +1221,9 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
     @skip_pre_blackwell
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend",
-        [
-            (1, 1, 1, True, True, True, "CUTLASS"),
-            # TODO: enable TRTLLM backend
-            # (1, 1, 1, True, True, True, "TRTLLM"),
-        ],
-        ids=[
-            "latency_moe_cutlass",
-            # "latency_moe_trtllm",
-        ],
+        [(1, 1, 1, True, True, True, "CUTLASS"),
+         (1, 1, 1, False, True, True, "TRTLLM")],
+        ids=["latency_moe_cutlass", "latency_moe_trtllm"],
     )
     def test_nvfp4(
         self,
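
The strings passed as ids become the bracketed suffix of the pytest node ID, which is why latency_moe_trtllm appears verbatim in the test lists updated below. A small, self-contained illustration of that naming behavior (not tied to the TensorRT-LLM suite):

import pytest


# Collected as test_moe_backend[latency_moe_cutlass] and
# test_moe_backend[latency_moe_trtllm], mirroring how the QA lists
# reference test_nvfp4[latency_moe_trtllm].
@pytest.mark.parametrize(
    "moe_backend",
    ["CUTLASS", "TRTLLM"],
    ids=["latency_moe_cutlass", "latency_moe_trtllm"],
)
def test_moe_backend(moe_backend):
    assert moe_backend in ("CUTLASS", "TRTLLM")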

tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 1 addition & 0 deletions
@@ -472,6 +472,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput_tp
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass]
+accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency]

tests/integration/test_lists/qa/llm_sanity_test.txt

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[fp8kv=False-att
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass]
+accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm]
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype
 
 # Pivot to Pytorch test cases.

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ l0_b200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass]
+- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm]
 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
 - test_e2e.py::test_ptq_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]

tests/unittest/_torch/thop/test_moe.py

Lines changed: 1 addition & 1 deletion
@@ -705,7 +705,7 @@ def check_accuracy(a, b, atol, rtol, percent):
                 "has_routing_bias": False,
                 "routing_method_type": RoutingMethodType.Qwen3
             },
-            id="Qwen3"),
+            id="RoutingQwen3"),
     ],
 )
 def test_moe_fp4(num_tokens, hidden_size, intermediate_size, routing_info):
