
Commit 4b13d36

wip maybe not relevant
1 parent 7473296 commit 4b13d36

File tree

6 files changed: +194 additions, -34 deletions


chatgpt.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import subprocess
+import re
+
+# Values to replace the last parameter (1)
+values = [
+    1, 2, 4, 8, 16, 24, 32, 48, 64,
+    96, 128, 256, 512, 1024, 1536,
+    2048, 3072, 4096
+]
+
+base_cmd = (
+    "pytest -s "
+    "'tests/test_trtllm_cutlass_fused_moe.py::"
+    "test_moe_nvfp4[True-True-otype0-wtype0-256-8-256-7168-{}]'"
+)
+
+time_pattern = re.compile(r"Elapsed time: ([\d.]+) ms")
+
+results = []
+
+for v in values:
+    print(f"Running with last param = {v}")
+    cmd = base_cmd.format(v)
+    try:
+        output = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT, text=True)
+        match = time_pattern.search(output)
+        if match:
+            elapsed_time = float(match.group(1))
+        else:
+            elapsed_time = None
+            print(f"Warning: Elapsed time not found in output for {v}")
+    except subprocess.CalledProcessError as e:
+        output = e.output
+        elapsed_time = None
+        print(f"Error running test for {v}:\n{output}")
+
+    results.append((v, elapsed_time))
+
+# Print results as a table
+print("\nResults:")
+print(f"{'Value':>6} | {'Time (ms)':>10}")
+print("-" * 20)
+for val, time in results:
+    time_str = f"{time:.2f}" if time is not None else "N/A"
+    print(f"{val:6} | {time_str:>10}")
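For reference, base_cmd.format(64) expands to this exact shell command:

    pytest -s 'tests/test_trtllm_cutlass_fused_moe.py::test_moe_nvfp4[True-True-otype0-wtype0-256-8-256-7168-64]'

so the script shells out to pytest once per batch size and scrapes the "Elapsed time: ... ms" line that the modified test below prints.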

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(
 
 std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
     CutlassGemmConfig::CandidateConfigTypeParam const config) {
-#ifdef FAST_BUILD
+#ifdef False //FAST_BUILD
   // Fast build disables all configs except this one for SM100
   return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                             MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
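Since False is presumably never defined as a preprocessor macro, the edited #ifdef makes the branch unconditionally dead, so get_candidate_configs_sm100 always enumerates the full SM100 candidate set even in FAST_BUILD builds; #if 0 would be the more conventional idiom for the same effect.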

csrc/nv_internal/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h

Lines changed: 1 addition & 1 deletion
@@ -835,7 +835,7 @@ struct GemmProfilerBackend {
   mWType = wtype;
   mOType = otype;
   mNumExperts = num_experts;
-  mNumExpertsPerNode = num_experts / (parallelism_config.ep_size * parallelism_config.tp_size);
+  mNumExpertsPerNode = num_experts / (parallelism_config.ep_size);// * parallelism_config.tp_size);
   mK = k;
   mExpertHiddenSize = hidden_size;
   mExpertInterSize = inter_size;
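This makes the GEMM profiler shard experts across expert-parallel ranks only, ignoring tensor parallelism. A quick sketch of the arithmetic in Python, with hypothetical ep_size/tp_size values (illustrative, not taken from this commit):

num_experts = 256          # matches the test parameterization in this commit
ep_size, tp_size = 2, 4    # hypothetical parallelism configuration

old_per_node = num_experts // (ep_size * tp_size)  # old divisor: 32 experts per node
new_per_node = num_experts // ep_size              # new divisor: 128 experts per node
print(old_per_node, new_per_node)

With these values the profiler now sizes its workspaces for four times as many experts per node.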

flashinfer/autotuner.py

Lines changed: 14 additions & 1 deletion
@@ -382,6 +382,9 @@ def search_cache(
             cache_key = r.get_cache_key(custom_op, input_shapes, tuning_config)
 
             if cache_key in self.profiling_cache:
+                # print(f"self.profiling_cache:{len(self.profiling_cache)}")
+                # # print("cache hit", cache_key)
+                # print(tuning_config)
                 return True, *self.profiling_cache[cache_key]
 
         return False, 0, -1, None
@@ -452,9 +455,13 @@ def choose_one(
         )
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)
+        # print("xxx"*20)
+        # print(f"profiles:{len(profiles)}")
 
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
+            # [print(i.shape) for i in tensors]
+            # [print(i.dtype) for i in tensors]
             is_cache_hit, runner, tactic, _ = self.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config
             )
@@ -464,17 +471,20 @@ def choose_one(
             runner, tactic = None, None
             for runner_id, r in enumerate(runners):
                 # TODO: use FakeTensor here.
+                # [print(t.shape) for t in tensors]
                 valid_tactics = r.get_valid_tactics(tensors)
                 runner_arg_names = {
                     p.name for p in inspect.signature(r.forward).parameters.values()
                 }
                 if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
                     r(tensors, tactic=-1, do_preparation=True, **kwargs)
+                # print(f"valid_tactics: {len(valid_tactics)}")
                 for tac in valid_tactics:
                     try:
                         time_measured = self._profile_single_kernel(
                             r, tensors, tac, **kwargs
                         )
+                        # print(f"time_measured: {time_measured}, {tac}")
                     except Exception as e:
                         logger.error(
                             f"[Autotuner]: Failed when profiling {r} {tac}, shapes={[t.size() for t in tensors]}. Error occurred: {e}"
@@ -508,13 +518,16 @@ def choose_one(
             logger.debug(
                 f"[Autotuner]: profiling chosen runner: {runner} {tactic} for {cache_key}"
             )
+            # print(f"[Autotuner]: profiling chosen runner: {runner} {tactic} for {cache_key}")
+
 
         # Get the best runner and tactic from cache
         # If no valid tactic is found, the fallback runner and tactic will be used
+        # print("search cache")
         _, runner_id, tactic, _ = self.search_cache(
             custom_op, runners, input_shapes, tuning_config
         )
-
+        # print(f"returning tactic: {tactic} for {runners[runner_id]}")
         return runners[runner_id], tactic
 
     def _profile_single_kernel(
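Every added line in this file is a commented-out trace print; the only behavioral change is the dropped blank line before the final return. For context, the flow these prints instrument is a profile-then-cache loop. A minimal sketch of that pattern, as a simplified stand-in rather than the actual flashinfer AutoTuner API:

from typing import Any, Callable, Dict, List, Tuple

_profiling_cache: Dict[str, Tuple[Any, int]] = {}

def choose_one(
    op: str,
    runners: List[Any],
    tactics_for: Callable[[Any], List[int]],
    profile_one: Callable[[Any, int], float],
) -> Tuple[Any, int]:
    """Return the fastest (runner, tactic) pair for op, memoizing the result."""
    if op in _profiling_cache:  # cache hit: skip profiling entirely
        return _profiling_cache[op]
    best, best_time = (runners[0], -1), float("inf")  # fallback if nothing profiles
    for runner in runners:
        for tactic in tactics_for(runner):
            try:
                elapsed = profile_one(runner, tactic)  # time one kernel launch
            except Exception:
                continue  # a failing tactic is skipped, not fatal
            if elapsed < best_time:
                best, best_time = (runner, tactic), elapsed
    _profiling_cache[op] = best
    return best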

flashinfer/fused_moe.py

Lines changed: 9 additions & 3 deletions
@@ -95,7 +95,7 @@ def gen_fused_moe_sm100_module() -> JitSpec:
             "-DCOMPILE_HOPPER_TMA_GEMMS",
         ],
         extra_cflags=[
-            "-DFAST_BUILD",
+            # "-DFAST_BUILD",
         ],
         extra_ldflags=["-lcuda"],
         extra_include_paths=[
@@ -195,7 +195,6 @@ def get_valid_tactics(
         invalid = (m > 128 and min_latency_mode) or (
             m <= 128 and min_latency_mode and (not self._is_nvfp4)
         )
-
         return (
             [] if invalid else list(range(self._fused_moe_runner.get_tactic_num()))
         )
@@ -210,6 +209,10 @@ def forward(
         x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
         min_latency_mode = min_latency_mode_tensor.size(0) == 1
         # determine if we should use min latency mode according to the profiled seq len
+        # print("uuuu"*10)
+        # import traceback
+        # traceback.print_stack()
+        # print(f"do_preparation: {do_preparation}, gemm_idx: {gemm_idx}, tactic: {tactic}")
         self._fused_moe_runner.run_gemm_profile(
             x,
             fc1_expert_weights,
@@ -309,7 +312,10 @@ def next_positive_power_of_2(x: int) -> int:
         [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
         gemm_idx=2,
     )
-
+    # print(f"input:{input.shape}")
+    # print(f"fc1_expert_weights:{fc1_expert_weights.shape}")
+    # print(f"fc2_expert_weights:{fc2_expert_weights.shape}")
+    print(gemm_tactic_1, gemm_tactic_2)
    run_moe = (
        moe_runner._fused_moe_runner.run_moe_min_latency
        if min_latency_mode
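Two details here: commenting out -DFAST_BUILD re-enables the full kernel config space in JIT builds, matching the cutlass_heuristic.cpp change above, and print(gemm_tactic_1, gemm_tactic_2) is left as a live print, unlike the other debug lines, which are commented out. The hunk header also names next_positive_power_of_2; its body is not shown in this diff, but a typical implementation looks like the following (an assumed sketch, not the file's actual code):

def next_positive_power_of_2(x: int) -> int:
    # Round x up to the nearest power of two, clamping at 1 for x <= 0.
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()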

tests/test_trtllm_cutlass_fused_moe.py

Lines changed: 123 additions & 28 deletions
@@ -217,16 +217,31 @@ def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size):
         num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
     )
     flash_output = torch.empty_like(ref_output)
-    flash_output = fused_moe.cutlass_fused_moe(
-        x,
-        selected_experts.to(torch.int),
-        routing_weights,
-        w31_weight,
-        w2_weight,
-        flash_output.dtype,
-        output=flash_output,
-        quant_scales=None,
-    )
+
+    from flashinfer.autotuner import autotune
+    with torch.inference_mode(), autotune():
+        flash_output = fused_moe.cutlass_fused_moe(
+            x,
+            selected_experts.to(torch.int),
+            routing_weights,
+            w31_weight,
+            w2_weight,
+            flash_output.dtype,
+            output=flash_output,
+            quant_scales=None,
+        )
+    print("xxx"*100)
+    flash_output2 = torch.empty_like(ref_output)
+    flash_output2 = fused_moe.cutlass_fused_moe(
+        x,
+        selected_experts.to(torch.int),
+        routing_weights,
+        w31_weight,
+        w2_weight,
+        ref_output.dtype,
+        output=flash_output2,
+        quant_scales=None,
+    )
     torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2)
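The pattern here: the first call runs under autotune() so the AutoTuner profiles the available tactics and fills its cache, and the second call, outside the context manager, replays the cached best tactic; flash_output2 captures that cached-path result.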
@@ -308,16 +323,27 @@ def test_moe_fp8(
     torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
 
 
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
-@pytest.mark.parametrize("top_k", TOP_K_VALUES)
-@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
+# @pytest.mark.parametrize("batch_size", BATCH_SIZES)
+# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+# @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+# @pytest.mark.parametrize("top_k", TOP_K_VALUES)
+# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
+# @pytest.mark.parametrize(
+#     "otype, wtype",
+#     [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
+# )
+
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096])
+@pytest.mark.parametrize("hidden_size", [7168])
+@pytest.mark.parametrize("num_experts", [256])
+@pytest.mark.parametrize("top_k", [8])
+@pytest.mark.parametrize("intermediate_size", [256])
 @pytest.mark.parametrize(
     "otype, wtype",
-    [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
+    [(torch.bfloat16, torch.float8_e4m3fn)],
 )
 @pytest.mark.parametrize("quantized_input", [False, True])
+@pytest.mark.parametrize("use_autotune", [False, True])
 def test_moe_nvfp4(
     batch_size,
     hidden_size,
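With this parameterization, the generated pytest node IDs line up with the template hard-coded in chatgpt.py: in test_moe_nvfp4[True-True-otype0-wtype0-256-8-256-7168-{batch}], the fields appear to map to use_autotune=True, quantized_input=True, the single (bfloat16, float8_e4m3fn) dtype pair, intermediate_size=256, top_k=8, num_experts=256, and hidden_size=7168, with the trailing slot swept over the same batch sizes that script iterates.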
@@ -327,6 +353,7 @@ def test_moe_nvfp4(
     otype,
     wtype,
     quantized_input,
+    use_autotune,
 ):
     # Skip invalid configurations
     if top_k > num_experts:
@@ -410,17 +437,85 @@ def test_moe_nvfp4(
     input_sf = None
     if quantized_input:
         hidden_states, input_sf = fp4_quantize(x, a1_gs)
-    _ = fused_moe.cutlass_fused_moe(
-        hidden_states,
-        selected_experts.to(torch.int),
-        routing_weights,
-        w1_q.contiguous().view(torch.long),
-        w2_q.contiguous().view(torch.long),
-        otype,
-        quant_scales=quant_scales,
-        input_sf=input_sf,
-        output=flash_output,
-    )
+    print(hidden_states.dtype)
+
+    # Timing starts here
+    runtimes = 6
+    flash_output2 = torch.zeros_like(x)
+    if not use_autotune:
+        # warmup
+        for _ in range(runtimes):
+            _ = fused_moe.cutlass_fused_moe(
+                hidden_states,
+                selected_experts.to(torch.int),
+                routing_weights,
+                w1_q.contiguous().view(torch.long),
+                w2_q.contiguous().view(torch.long),
+                otype,
+                quant_scales=quant_scales,
+                input_sf=input_sf,
+                output=flash_output2,
+            )
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(runtimes):
+            _ = fused_moe.cutlass_fused_moe(
+                hidden_states,
+                selected_experts.to(torch.int),
+                routing_weights,
+                w1_q.contiguous().view(torch.long),
+                w2_q.contiguous().view(torch.long),
+                otype,
+                quant_scales=quant_scales,
+                input_sf=input_sf,
+                output=flash_output2,
+            )
+        end_event.record()
+
+        # Wait for completion
+        torch.cuda.synchronize()
+        elapsed_time_ms = start_event.elapsed_time(end_event) / runtimes
+        print(f"No autotune Elapsed time: {elapsed_time_ms:.2f} ms")
+    else:
+        from flashinfer.autotuner import autotune, AutoTuner
+        AutoTuner.get().clear_cache()
+        with torch.inference_mode(), autotune():
+            for _ in range(5):
+                _ = fused_moe.cutlass_fused_moe(
+                    hidden_states,
+                    selected_experts.to(torch.int),
+                    routing_weights,
+                    w1_q.contiguous().view(torch.long),
+                    w2_q.contiguous().view(torch.long),
+                    otype,
+                    quant_scales=quant_scales,
+                    input_sf=input_sf,
+                    output=flash_output,
+                )
+        # Timing starts here
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(runtimes):
+            _ = fused_moe.cutlass_fused_moe(
+                hidden_states,
+                selected_experts.to(torch.int),
+                routing_weights,
+                w1_q.contiguous().view(torch.long),
+                w2_q.contiguous().view(torch.long),
+                otype,
+                quant_scales=quant_scales,
+                input_sf=input_sf,
+                output=flash_output2,
+            )
+        end_event.record()
+
+        # Wait for completion
+        torch.cuda.synchronize()
+        elapsed_time_ms = start_event.elapsed_time(end_event) / runtimes
+        print(f"Elapsed time: {elapsed_time_ms:.2f} ms")
 
     # Ref check
     a_fp4, a_scale_interleaved = fp4_quantize(x, a1_gs)
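Both branches above use the standard CUDA-event timing recipe: warm up, record a start event, launch N times, record an end event, synchronize, then average. A minimal self-contained sketch of that recipe, with a stand-in matmul workload instead of cutlass_fused_moe (assumes a CUDA device is available):

import torch

def time_kernel_ms(fn, iters: int = 6) -> float:
    for _ in range(iters):
        fn()  # warmup: absorb JIT/compile and cache effects
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()  # enqueue the start marker on the current stream
    for _ in range(iters):
        fn()
    end.record()  # enqueue the end marker after the launches
    torch.cuda.synchronize()  # wait until both events have completed
    return start.elapsed_time(end) / iters  # average milliseconds per call

x = torch.randn(4096, 4096, device="cuda")
print(f"{time_kernel_ms(lambda: x @ x):.2f} ms")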
@@ -462,7 +557,7 @@ def test_moe_nvfp4(
     ref_output = torch_moe_nvfp4(
         a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts
     )
-    torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
+    # torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
 
 
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
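Note that with assert_close commented out, test_moe_nvfp4 no longer verifies the MoE output against the reference; as committed, it only reports timings, consistent with the commit's WIP benchmarking intent.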
