@@ -16,7 +16,7 @@
 
 from abc import ABC, abstractmethod
 from enum import IntEnum
-from typing import Literal
+from typing import Dict, Literal
 
 import pytest
 import torch
@@ -30,15 +30,19 @@
     next_positive_power_of_2,
     reorder_rows_for_gated_act_gemm,
     shuffle_matrix_a,
-    shuffle_matrix_sf_a,
 )
+from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
 from flashinfer.fused_moe import (
     WeightLayout,
     convert_to_block_layout,
     trtllm_fp4_block_scale_moe,
     trtllm_fp8_block_scale_moe,
     trtllm_fp8_per_tensor_scale_moe,
 )
+from flashinfer.fused_moe.core import (
+    _maybe_get_cached_w2_permute_indices,
+    _maybe_get_cached_w3_w1_permute_indices,
+)
 
 
 def check_cuda(err):
@@ -386,50 +390,67 @@ def prepare_static_weights_for_kernel(
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
+        # Using cached permute index calculation can speed up weights preprocessing
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
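A note on the hunk above: the old path built interleaved per-expert copies of the W1 weights and scales (reorder_rows_for_gated_act_gemm, torch.stack, reshape) and then shuffled them with shuffle_matrix_a / shuffle_matrix_sf_a. The new path expresses both the gated-activation row reordering and the transposed-MMA shuffle as one row permutation whose indices are cached by weight shape; the weights are shuffled with a single gather, and the scale factors are additionally interleaved via nvfp4_block_scale_interleave. Below is a minimal sketch of the caching pattern only (illustrative, not the flashinfer implementation; build_indices is a hypothetical stand-in for the real _maybe_get_cached_w3_w1_permute_indices / _maybe_get_cached_w2_permute_indices helpers in flashinfer.fused_moe.core).

# Illustrative sketch of the shape-keyed caching idea used above.
from typing import Callable, Dict

import torch


def get_cached_permute_indices(
    cache: Dict[torch.Size, torch.Tensor],
    weight: torch.Tensor,
    build_indices: Callable[[torch.Size], torch.Tensor],
) -> torch.Tensor:
    if weight.shape not in cache:
        # The expensive part (computing the row permutation) runs only once
        # per distinct weight shape; later calls are dictionary lookups.
        cache[weight.shape] = build_indices(weight.shape)
    return cache[weight.shape]


# Applying the permutation is then a cheap gather, as in the hunk above:
#   shuffled = weight.view(torch.uint8)[indices.to(weight.device)].contiguous()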
@@ -1627,6 +1648,12 @@ def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) ->
     return tile_tokens_dim
 
 
+@pytest.fixture(scope="module")
+def cache_permute_indices():
+    _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+    return _cache_permute_indices
+
+
 @pytest.mark.parametrize("num_tokens", [1, 1024])
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [1024, 768, 384])
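The cache_permute_indices fixture is module-scoped, so the same Dict[torch.Size, torch.Tensor] is shared by every parametrized test case in the file; experts and test variants that reuse a weight shape hit the cache instead of recomputing the permutation, which is what makes the cached path in prepare_static_weights_for_kernel pay off.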
@@ -1758,6 +1785,7 @@ def test_moe_quantization_classes(
     moe_impl,
     routing_config,
     weight_processing,
+    cache_permute_indices,
 ):
     """
     Test MoE implementations using separated quantization workflow.
@@ -1778,6 +1806,8 @@ def test_moe_quantization_classes(
             f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
         )
 
+    moe_impl._cache_permute_indices = cache_permute_indices
+
     seed = 0
     torch.random.manual_seed(seed)
 