Commit f4d10a7

yzh119 and dbari authored
bugfix: fix the enum/int type mismatch mentioned in #2507 (#2508)
## 📌 Description

As mentioned in #2507, the `trtllm_fp8_per_tensor_scale_moe` function would fail when passed an integer `activation_type`. This PR fixes the type mismatch.

## 🔍 Related Issues

#2507

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Refactor**
  * Public APIs updated: `activation_type` now accepts integer values (defaults adjusted to numeric activation codes).
  * Call sites and public function signatures aligned to use the numeric `activation_type`.
* **Tests**
  * Test inputs updated to supply numeric `activation_type` values instead of enum members.

---------

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
1 parent 292f9be commit f4d10a7
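To illustrate the failure mode described above, here is a minimal, self-contained sketch — not the FlashInfer code itself; `ActivationType`, its member values, and the two wrapper functions are stand-ins. Before this change the wrapper called `.value` on the argument, so a plain `int` raised `AttributeError`; after it, the argument is already numeric and is forwarded unchanged.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    # Stand-in for flashinfer's ActivationType; member values are illustrative.
    Swiglu = 0
    Geglu = 1
    Relu2 = 2


def moe_op_old(activation_type: ActivationType = ActivationType.Swiglu) -> int:
    # Pre-fix pattern: the wrapper unwraps the enum itself ...
    return activation_type.value  # ... so a plain int raises AttributeError here


def moe_op_new(activation_type: int = ActivationType.Swiglu.value) -> int:
    # Post-fix pattern: the argument is already numeric and is forwarded as-is.
    return activation_type


moe_op_old(ActivationType.Geglu)        # fine: 1
# moe_op_old(1)                         # AttributeError: 'int' object has no attribute 'value'
moe_op_new(ActivationType.Geglu.value)  # fine: 1
moe_op_new(1)                           # fine: 1
```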

File tree

2 files changed: +12 −12 lines

flashinfer/fused_moe/core.py

Lines changed: 4 additions & 4 deletions
@@ -989,7 +989,7 @@ def __init__(
         use_deepseek_fp8: bool,
         hidden_size: int,
         intermediate_size: int,
-        activation_type: int = ActivationType.Swiglu,
+        activation_type: int = ActivationType.Swiglu.value,
         use_shuffled_weight: bool = False,
         weight_layout: int = WeightLayout.MajorK,
         use_packed_weights: bool = False,
@@ -1422,7 +1422,7 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     routing_method_type: int = 0,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
-    activation_type: ActivationType = ActivationType.Swiglu,
+    activation_type: int = ActivationType.Swiglu.value,
 ) -> torch.Tensor:
     if enable_pdl is None:
         enable_pdl = device_support_pdl(hidden_states.device)
@@ -1482,7 +1482,7 @@ def trtllm_fp8_per_tensor_scale_moe_op(
         use_routing_scales_on_input=use_routing_scales_on_input,
         routing_method_type=routing_method_type,
         enable_pdl=enable_pdl,
-        activation_type=activation_type.value,
+        activation_type=activation_type,
     )
     # Call the C++ function
     result = moe_op.trtllm_fp8_per_tensor_scale_moe(
@@ -1507,7 +1507,7 @@ def trtllm_fp8_per_tensor_scale_moe_op(
         routing_method_type,
         enable_pdl,
         [-1, -1] if tactic == -1 else tactic,
-        activation_type.value,
+        activation_type,
     )
 
     return result
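A note on the design choice: the public signature now declares `activation_type: int` with a default of `ActivationType.Swiglu.value`, so callers pass the numeric code and the wrapper forwards it unchanged to the C++ op. If one also wanted to keep accepting enum members, a small normalization helper would do it; the sketch below shows that alternative under the assumption that `ActivationType` behaves like an `IntEnum`, and is not what this PR implements.

```python
from enum import IntEnum
from typing import Union


class ActivationType(IntEnum):
    # Stand-in for flashinfer's ActivationType; member values are illustrative.
    Swiglu = 0
    Geglu = 1
    Relu2 = 2


def normalize_activation_type(activation_type: Union[int, ActivationType]) -> int:
    # IntEnum members coerce cleanly to int; plain ints pass through unchanged.
    return int(activation_type)


assert normalize_activation_type(ActivationType.Geglu) == ActivationType.Geglu.value
assert normalize_activation_type(1) == 1
```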

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 8 additions & 8 deletions
@@ -2667,8 +2667,8 @@ def run_moe_test(
 @pytest.mark.parametrize(
     "activation_type",
     [
-        pytest.param(ActivationType.Swiglu, id="Swiglu"),
-        pytest.param(ActivationType.Geglu, id="Geglu"),
+        pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
+        pytest.param(ActivationType.Geglu.value, id="Geglu"),
     ],
 )
 def test_renormalize_routing(
@@ -2855,9 +2855,9 @@ def test_renormalize_routing(
 @pytest.mark.parametrize(
     "activation_type",
     [
-        pytest.param(ActivationType.Swiglu, id="Swiglu"),
-        pytest.param(ActivationType.Geglu, id="Geglu"),
-        pytest.param(ActivationType.Relu2, id="Relu2"),
+        pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
+        pytest.param(ActivationType.Geglu.value, id="Geglu"),
+        pytest.param(ActivationType.Relu2.value, id="Relu2"),
     ],
 )
 def test_deepseekv3_routing(
@@ -2931,8 +2931,8 @@ def test_deepseekv3_routing(
 @pytest.mark.parametrize(
     "activation_type",
     [
-        pytest.param(ActivationType.Swiglu, id="Swiglu"),
-        pytest.param(ActivationType.Geglu, id="Geglu"),
+        pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
+        pytest.param(ActivationType.Geglu.value, id="Geglu"),
     ],
 )
 def test_topk_routing(
@@ -3005,7 +3005,7 @@ def test_topk_routing(
 @pytest.mark.parametrize(
     "activation_type",
     [
-        pytest.param(ActivationType.Swiglu, id="Swiglu"),
+        pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
     ],
 )
 def test_llama4_routing(
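One side effect of parametrizing with plain ints is that the generated test ids would otherwise show bare numbers; the explicit `id="..."` strings in `pytest.param` keep the names readable. A tiny self-contained illustration (the numeric codes and the test are made up here, only the pytest mechanics are real):

```python
import pytest

# Stand-in numeric codes; the real values come from flashinfer's ActivationType enum.
SWIGLU, GEGLU = 0, 1


@pytest.mark.parametrize(
    "activation_type",
    [
        # Without id=..., the test ids would render as "[0]" and "[1]";
        # pytest.param keeps the readable names while passing plain ints.
        pytest.param(SWIGLU, id="Swiglu"),
        pytest.param(GEGLU, id="Geglu"),
    ],
)
def test_activation_type_is_int(activation_type):
    assert isinstance(activation_type, int)
```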
