Commit f5f1c55
fix test_trtllm_gen_fused_moe.py

Signed-off-by: jiahanc <[email protected]>
Parent: cdd6389

3 files changed: +4 additions, −36 deletions

csrc/trtllm_batched_gemm_runner.cu

1 addition, 1 deletion

@@ -169,7 +169,7 @@ void TrtllmGenBatchedGemmRunner::run(
   auto const configs = bmm.getBatchedGemmConfigs();
 
   auto const& config = configs[configIndex];
-
+  // std::cout << "Running GEMM with config: " << config.mFunctionName << std::endl;
   FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
   if (!mOptions.staticBatch) {
     FLASHINFER_CHECK(totalNumPaddedTokens,

flashinfer/fused_moe/core.py

2 additions, 1 deletion

@@ -1345,6 +1345,7 @@ def trtllm_fp4_block_scale_moe_op(
     if hidden_states_scale is not None:
         inputs.append(hidden_states_scale)
 
+    print(f"fp4 block scale moe tunning start")
     _, tactic = tuner.choose_one(
         "flashinfer::trtllm_fp4_block_scale_moe",
         [moe_runner],
@@ -1373,7 +1374,7 @@ def trtllm_fp4_block_scale_moe_op(
         do_finalize=do_finalize,
         gated_act_type=gated_act_type,
     )
-
+    print(f"fp4 block scale moe tunning end with tactic {tactic}")
     # Call the C++ function for block scale MoE
     output = moe_op.trtllm_fp4_block_scale_moe(
         routing_logits,
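
The two print calls added here bracket the autotuner call, so it is visible when tuning of the FP4 block-scale MoE op starts and which tactic choose_one settles on. Not part of this commit: a minimal sketch, using only the Python standard library, of how the same trace could go through logging and be switched on by an environment variable instead of printing unconditionally. The FLASHINFER_TUNING_DEBUG variable and the _log_tactic helper are hypothetical names used for illustration.

import logging
import os

# Hypothetical helper, not part of flashinfer: emit tuning traces only when
# FLASHINFER_TUNING_DEBUG=1 is set, instead of unconditional print() calls.
logger = logging.getLogger("flashinfer.fused_moe.tuning")
if os.environ.get("FLASHINFER_TUNING_DEBUG") == "1":
    logging.basicConfig(level=logging.DEBUG)


def _log_tactic(stage: str, tactic=None) -> None:
    # stage is e.g. "start" or "end"; tactic is whatever choose_one returned.
    if tactic is None:
        logger.debug("fp4 block scale moe tuning %s", stage)
    else:
        logger.debug("fp4 block scale moe tuning %s with tactic %s", stage, tactic)

# At the call site this would be _log_tactic("start") before choose_one and
# _log_tactic("end", tactic) after it.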

tests/moe/test_trtllm_gen_fused_moe.py

1 addition, 34 deletions

@@ -17,7 +17,6 @@
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from typing import Dict
-
 import pytest
 import torch
 from cuda.bindings import runtime
@@ -1839,7 +1838,7 @@ def cache_permute_indices():
 
 @pytest.mark.parametrize("num_tokens", [1, 8, 1024])
 @pytest.mark.parametrize("hidden_size", [1024, 8192])
-@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384])
+@pytest.mark.parametrize("intermediate_size", [384, 768, 1024, 2048])
 @pytest.mark.parametrize(
     "moe_impl",
     [
@@ -2244,35 +2243,3 @@ def test_moe_quantization_classes(
         rtol=tolerances["rtol"],
         percent=tolerances["percent"],
     )
-
-
-if __name__ == "__main__":
-    # pytest.main([__file__, "-v"])
-    routing_config = {
-        "num_experts": 256,
-        "top_k": 8,
-        "padding": 8,
-        "n_groups": 8,
-        "top_k_groups": 4,
-        "routed_scaling": 2.5,
-        "has_routing_bias": True,
-        "routing_method_type": RoutingMethodType.DeepSeekV3,
-        "compatible_moe_impls": [
-            FP8BlockScaleMoe,
-        ],
-    }
-    weight_processing = {
-        "use_shuffled_weight": False,
-        "layout": WeightLayout.MajorK,
-        "compatible_moe_impls": [FP8BlockScaleMoe],
-    }
-    test_moe_quantization_classes(
-        num_tokens=4,
-        hidden_size=1024,
-        intermediate_size=1024,
-        moe_impl=FP8BlockScaleMoe(),
-        routing_config=routing_config,
-        weight_processing=weight_processing,
-        gated_act_type=GatedActType.SwiGlu,
-        cache_permute_indices=cache_permute_indices,
-    )
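
The block removed at the bottom of the file was an ad-hoc __main__ driver that built a DeepSeekV3 routing config by hand and called test_moe_quantization_classes directly. Not part of this commit: a minimal sketch of running the same test through pytest instead, which also honors fixtures and parametrization; the -k filter shown is an assumption based only on the test name visible above.

# Hypothetical stand-alone runner, not part of the commit: select the
# quantization-classes test via pytest instead of calling it by hand.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            [
                "tests/moe/test_trtllm_gen_fused_moe.py",
                "-v",
                # -k filters by test name; the exact parametrized ID for a
                # specific FP8 block-scale case is an assumption and may differ.
                "-k", "test_moe_quantization_classes",
            ]
        )
    )

One reason the direct call was fragile is that cache_permute_indices appears to be a pytest fixture (see the cache_permute_indices definition referenced in the hunk header above), so passing the function object itself as an argument bypasses fixture resolution; letting pytest drive the test avoids that.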
