
Commit 160e4b8

feat: Faster weight processing (moe nvfp4) (#1412)
## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent e0c7019 commit 160e4b8

File tree: 2 files changed (+120, -34 lines)


flashinfer/fused_moe/core.py

Lines changed: 58 additions & 2 deletions
@@ -18,7 +18,7 @@
 from enum import IntEnum
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -34,7 +34,13 @@
 from ..jit import env as jit_env
 from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
 from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
-from ..utils import _check_shape_dtype_device, register_custom_op, register_fake_op
+from ..utils import (
+    _check_shape_dtype_device,
+    get_shuffle_matrix_a_row_indices,
+    get_shuffle_matrix_sf_a_row_indices,
+    register_custom_op,
+    register_fake_op,
+)
 from .utils import (
     get_last_power_of_2_num_tokens_buckets,
     last_positive_power_of_2,
@@ -69,6 +75,56 @@ class WeightLayout(IntEnum):
     BlockMajorK = 2
 
 
+def _maybe_get_cached_w3_w1_permute_indices(
+    _cache_permute_indices,
+    dst_w3_w1_weight: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: Union[None, int] = None,
+) -> torch.Tensor:
+    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+        # Get permute indices and chain them together
+        permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
+        if num_elts_per_sf is None:
+            permute1 = get_shuffle_matrix_a_row_indices(
+                dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m
+            )
+        else:
+            permute1 = get_shuffle_matrix_sf_a_row_indices(
+                dst_w3_w1_weight,
+                epilogue_tile_m=epilogue_tile_m,
+                num_elts_per_sf=num_elts_per_sf,
+            )
+        # Memoize permute indices as recompute is **very** costly
+        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+            dst_w3_w1_weight.device
+        )
+    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    return permute_indices
+
+
+def _maybe_get_cached_w2_permute_indices(
+    _cache_permute_indices,
+    dst_w2_weight: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: Union[None, int] = None,
+) -> torch.Tensor:
+    if dst_w2_weight.shape not in _cache_permute_indices:
+        if num_elts_per_sf is None:
+            permute_indices = get_shuffle_matrix_a_row_indices(
+                dst_w2_weight, epilogue_tile_m
+            ).to(dst_w2_weight.device)
+        else:
+            permute_indices = get_shuffle_matrix_sf_a_row_indices(
+                dst_w2_weight,
+                epilogue_tile_m=epilogue_tile_m,
+                num_elts_per_sf=num_elts_per_sf,
+            ).to(dst_w2_weight.device)
+        # Memoize permute indices as recompute is **very** costly
+        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
+    permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+    return permute_indices
+
+
 def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor:
     """
     Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
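
For orientation, here is a minimal sketch of how these two helpers are meant to be driven by a caller that shuffles a single expert's packed-fp4 weights. The shapes, the CUDA device, and `epilogue_tile_m = 128` below are illustrative assumptions, not values taken from this commit:

```python
# Hedged sketch (assumed shapes, device, and tile size): shuffle one expert's
# packed-fp4 weights using the new shape-keyed permute-index cache.
from typing import Dict

import torch

from flashinfer.fused_moe.core import (
    _maybe_get_cached_w2_permute_indices,
    _maybe_get_cached_w3_w1_permute_indices,
)

cache: Dict[torch.Size, torch.Tensor] = {}  # keyed by weight shape
epilogue_tile_m = 128                       # assumed tile size

# Packed fp4 weights viewed as uint8; shapes chosen only for illustration.
w3_w1 = torch.randint(0, 256, (2048, 512), dtype=torch.uint8, device="cuda")
w2 = torch.randint(0, 256, (1024, 512), dtype=torch.uint8, device="cuda")

# The first call for a given shape computes and memoizes the chained
# permutation (gated-act row reorder + shuffle); later calls with the same
# shape are plain dictionary lookups.
idx_w3_w1 = _maybe_get_cached_w3_w1_permute_indices(cache, w3_w1, epilogue_tile_m)
idx_w2 = _maybe_get_cached_w2_permute_indices(cache, w2, epilogue_tile_m)

w3_w1_shuffled = w3_w1[idx_w3_w1.to(w3_w1.device)].contiguous()
w2_shuffled = w2[idx_w2.to(w2.device)].contiguous()
```

For scale factors, the same helpers are called with `num_elts_per_sf=16` and the permuted tensor is then passed to `nvfp4_block_scale_interleave`, as the updated test below does.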

tests/test_trtllm_gen_fused_moe.py

Lines changed: 62 additions & 32 deletions
@@ -16,7 +16,7 @@
 
 from abc import ABC, abstractmethod
 from enum import IntEnum
-from typing import Literal
+from typing import Dict, Literal
 
 import pytest
 import torch
@@ -30,15 +30,19 @@
     next_positive_power_of_2,
     reorder_rows_for_gated_act_gemm,
     shuffle_matrix_a,
-    shuffle_matrix_sf_a,
 )
+from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
 from flashinfer.fused_moe import (
     WeightLayout,
     convert_to_block_layout,
     trtllm_fp4_block_scale_moe,
     trtllm_fp8_block_scale_moe,
     trtllm_fp8_per_tensor_scale_moe,
 )
+from flashinfer.fused_moe.core import (
+    _maybe_get_cached_w2_permute_indices,
+    _maybe_get_cached_w3_w1_permute_indices,
+)
 
 
 def check_cuda(err):
@@ -386,50 +390,67 @@ def prepare_static_weights_for_kernel(
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
+        # Using cached permute index calculation can speed up weights preprocessing
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
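The speedup in this hunk comes from memoization: every expert's weight (and scale) tensor has the same shape, so the permute indices are computed once per shape and then reused from `self._cache_permute_indices` for the remaining experts. A hedged sketch of that reuse pattern, with hypothetical sizes:

```python
# Hedged sketch (hypothetical expert count, shape, and tile size): all experts
# share one weight shape, so only the first iteration computes permute
# indices; the rest are hits on the shape-keyed cache.
import torch

from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices

cache = {}
num_experts, epilogue_tile_m = 8, 128  # assumed values
gemm2_weights = torch.randint(
    0, 256, (num_experts, 1024, 512), dtype=torch.uint8, device="cuda"
)

for i in range(num_experts):
    idx = _maybe_get_cached_w2_permute_indices(
        cache, gemm2_weights[i], epilogue_tile_m
    )  # computed for i == 0, cache hit afterwards

assert len(cache) == 1  # one cached entry per distinct weight shape
```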

@@ -1627,6 +1648,12 @@ def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) ->
     return tile_tokens_dim
 
 
+@pytest.fixture(scope="module")
+def cache_permute_indices():
+    _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+    return _cache_permute_indices
+
+
 @pytest.mark.parametrize("num_tokens", [1, 1024])
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [1024, 768, 384])
@@ -1758,6 +1785,7 @@ def test_moe_quantization_classes(
     moe_impl,
     routing_config,
     weight_processing,
+    cache_permute_indices,
 ):
     """
     Test MoE implementations using separated quantization workflow.
@@ -1778,6 +1806,8 @@ def test_moe_quantization_classes(
             f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
         )
 
+    moe_impl._cache_permute_indices = cache_permute_indices
+
     seed = 0
     torch.random.manual_seed(seed)
 
