@@ -1064,7 +1064,7 @@ def __init__(
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
@@ -1197,19 +1197,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                                  weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
 
         # Convert quantized weights to proper formats
@@ -1227,48 +1231,54 @@ def prepare_static_weight_layouts_for_trtllm_moe(
                                                  intermediate_size //
                                                  16)  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))
 
         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@@ -1283,8 +1293,12 @@ def prepare_static_weight_layouts_for_trtllm_moe(
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.allow_flashinfer and \
             self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                  layer.w13_weight,
                  layer.w2_weight,
                  layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)
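
A minimal sketch of the shape-keyed caching pattern that the new `_cache_permute_indices` dict enables: the flashinfer helpers imported in the diff (`_maybe_get_cached_w3_w1_permute_indices` / `_maybe_get_cached_w2_permute_indices`) take the cache plus a weight view and return the row-permutation indices for the gated-activation reordering and the epilogue-tile shuffle, so tensors with the same shape reuse one set of indices instead of recomputing them. The `compute_permute_indices` and `maybe_get_cached_permute_indices` names below are hypothetical placeholders (an identity permutation stands in for the real layout computation); only the caching and row-indexing pattern mirrors the diff.

import torch


def compute_permute_indices(weight: torch.Tensor) -> torch.Tensor:
    # Placeholder for the kernel-specific layout computation (row reordering
    # for the fused gated activation plus the epilogue-tile shuffle).
    return torch.arange(weight.shape[0])


def maybe_get_cached_permute_indices(
        cache: dict[torch.Size, torch.Tensor],
        weight: torch.Tensor) -> torch.Tensor:
    # Indices depend only on the weight's shape, so every tensor with the
    # same shape (e.g. each expert's w3_w1 block) reuses the cached result.
    if weight.shape not in cache:
        cache[weight.shape] = compute_permute_indices(weight)
    return cache[weight.shape]


cache: dict[torch.Size, torch.Tensor] = {}
w = torch.randn(8, 4)
idx = maybe_get_cached_permute_indices(cache, w)
# Apply the permutation the same way the diff does: advanced row indexing,
# moving the cached indices to the weight's device first.
w_shuffled = w[idx.to(w.device)].contiguous()
assert maybe_get_cached_permute_indices(cache, w) is idx  # second call hits the cache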