Commit bc1a041

feat: Support loading autotuned results from json for cutlass fp4 moe backends (#1310)
This PR adds support for loading autotuned results from JSON files for the Cutlass FP4 MoE backends. The script `benchmarks/bench_cutlass_fused_moe.py` generates a JSON file at `configs/<flashinfer_version>/trtllm_fused_moe_<device_name>.json`, mapping input shapes to the optimal config/tactic for GEMMs used in `fused_moe.cutlass_fused_moe`. At runtime, setting the `FLASHINFER_AUTOTUNER_LOAD_FROM_FILE` environment variable enables loading from this file. If the variable is unset or a matching entry is not found, it falls back to the default config/tactic. Configs are organized by flashinfer version and GPU device. cc. @yzh119 @wenscarl @kushanam
1 parent 0ef9659 commit bc1a041
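
As a concrete illustration of the workflow described above (a minimal sketch, not part of the commit): the environment variable is read by the autotuner's cache lookup, and `get_config_path` (added in flashinfer/autotuner.py below) reports where the per-version, per-GPU configs live. The printed names are placeholders; the real ones depend on the installed flashinfer version and the local GPU.

import os

# Opt in to file-based autotuner configs; set this before the first MoE call.
os.environ["FLASHINFER_AUTOTUNER_LOAD_FROM_FILE"] = "1"

from flashinfer.autotuner import get_config_path

# Locations the autotuner consults (illustrative names).
print(get_config_path(is_module=False))  # e.g. .../flashinfer/tuning_configs/v<ver>_trtllm_fused_moe_<gpu>.py
print(get_config_path(is_module=True))   # e.g. flashinfer.tuning_configs.v<ver>_trtllm_fused_moe_<gpu>

If the variable is unset or no matching entry exists, the lookup falls back to the default config/tactic, as stated above.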

6 files changed: +226 additions, −54 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ repos:
     rev: 24.8.0
     hooks:
       - id: black
+        exclude: flashinfer/tuning_configs/.*\.py

   - repo: https://github.com/pycqa/isort
     rev: 5.13.2

benchmarks/bench_cutlass_fused_moe.py

Lines changed: 61 additions & 49 deletions
@@ -14,43 +14,20 @@
 limitations under the License.
 """

+import argparse
+import pprint
+
 import torch
 from torch.nn import functional as F
-from triton.testing import do_bench

-import flashinfer
 import flashinfer.fused_moe as fused_moe
 from flashinfer import fp4_quantize
+from flashinfer.autotuner import AutoTuner, autotune, get_config_path
+from flashinfer.testing.utils import bench_gpu_time_with_cudagraph

-BATCH_SIZES = [
-    1,
-    2,
-    4,
-    8,
-    16,
-    24,
-    32,
-    48,
-    64,
-    96,
-    128,
-    256,
-    512,
-    1024,
-    1536,
-    2048,
-    3072,
-    4096,
-]
-
-configs = []
-hidden_size = 7168
-num_experts = [32, 256]
-top_k = [8]
-intermediate_size = [256, 2048]
 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-FP8_DTYPE = torch.float8_e4m3fn
+

 test_configs = [
     {
@@ -96,6 +73,7 @@ def bench_cutlass_fused_moe(
     num_experts,
     top_k,
     intermediate_size,
+    skip_autotune,
 ):
     torch.manual_seed(42)
     quant_blocksize = 16
@@ -165,12 +143,24 @@ def bench_cutlass_fused_moe(
     ]
     hidden_states = x
     hidden_states, input_sf = fp4_quantize(x, a1_gs)
-    repeats = 3
-    from flashinfer.autotuner import AutoTuner, autotune

-    AutoTuner.get().clear_cache()
-    with torch.inference_mode(), autotune():
-        for _ in range(2):
+    # Warmup
+    for _ in range(3):
+        _ = fused_moe.cutlass_fused_moe(
+            hidden_states,
+            selected_experts.to(torch.int),
+            routing_weights,
+            w1_q.contiguous().view(torch.long),
+            w2_q.contiguous().view(torch.long),
+            otype,
+            quant_scales=quant_scales,
+            input_sf=input_sf,
+            output=flash_output,
+            tune_max_num_tokens=16384,
+        )
+
+    if not skip_autotune:
+        with torch.inference_mode(), autotune(True):
             _ = fused_moe.cutlass_fused_moe(
                 hidden_states,
                 selected_experts.to(torch.int),
@@ -181,8 +171,9 @@ def bench_cutlass_fused_moe(
                 quant_scales=quant_scales,
                 input_sf=input_sf,
                 output=flash_output,
+                tune_max_num_tokens=16384,
             )
-    ms = do_bench(
+    ms_list = bench_gpu_time_with_cudagraph(
         lambda: fused_moe.cutlass_fused_moe(
             hidden_states,
             selected_experts.to(torch.int),
@@ -195,23 +186,44 @@ def bench_cutlass_fused_moe(
             output=flash_output,
         )
     )
+    avg_ms = sum(ms_list) / len(ms_list)
+    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
     print(
-        f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}"
+        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {avg_ms:.3f}"
     )
-    print(f"execution time: {ms}ms")


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--update-config",
+        action="store_true",
+        help="Update the config file with the new profiling results",
+    )
+    parser.add_argument(
+        "--num-tokens", type=int, default=32, help="Number of tokens to profile"
+    )
+    parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
+    args = parser.parse_args()
+    AutoTuner.get().clear_cache()
+
     for config in test_configs:
-        hidden_size = config["hidden_size"]
-        num_experts = config["num_experts"]
-        top_k = config["top_k"]
-        intermediate_size = config["intermediate_size"]
-        for batch_size in BATCH_SIZES:
-            bench_cutlass_fused_moe(
-                batch_size,
-                hidden_size,
-                num_experts,
-                top_k,
-                intermediate_size,
-            )
+        bench_cutlass_fused_moe(
+            args.num_tokens,
+            config["hidden_size"],
+            config["num_experts"],
+            config["top_k"],
+            config["intermediate_size"],
+            args.skip_autotune,
+        )
+
+    configs = AutoTuner.get().profiling_cache
+    if args.update_config and configs:
+        # The original key contains a runner's hash in k[2] which might be different across machines.
+        # So, we remove it for now. v[0] and v[1] are the runner id and the tactic.
+        converted = {str((k[0], k[1], k[3])): (v[0], v[1]) for k, v in configs.items()}
+        config_path = get_config_path(is_module=False)
+        with open(config_path, "w") as f:
+            f.write("best_configs = ")
+            pprint.pprint(converted, stream=f)
+        print(f"Saved the cache to {config_path}")

flashinfer/autotuner.py

Lines changed: 46 additions & 4 deletions
@@ -1,21 +1,39 @@
 import contextlib
 import copy
+import importlib
 import inspect
 import itertools
+import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Any, Callable, Dict, List, Set, Tuple, Union

 import torch

+from flashinfer import __version__ as flashinfer_version
+
 # from tensorrt_llm.bindings.internal.runtime import delay_kernel
 # from tensorrt_llm.logger import logger
 from flashinfer.tllm_utils import delay_kernel

 from .jit.core import logger


+def get_config_path(is_module: bool):
+    dev_name = torch.cuda.get_device_name(0).replace(" ", "_")
+    fi_ver = flashinfer_version.replace(".", "_")
+    config_name = f"v{fi_ver}_trtllm_fused_moe_{dev_name}"
+    if is_module:
+        return f"flashinfer.tuning_configs.{config_name}"
+    else:
+        return os.path.join(
+            os.path.dirname(os.path.realpath(__file__)),
+            "tuning_configs",
+            config_name + ".py",
+        )
+
+
 @dataclass(slots=True, unsafe_hash=True)
 class DynamicTensorSpec:
     """
@@ -265,6 +283,25 @@ def __str__(self) -> str:
         return stats_str


+@lru_cache(maxsize=None)
+def load_from_file(key):
+    module_name = get_config_path(is_module=True)
+    try:
+        module = importlib.import_module(module_name)
+        best_configs = module.best_configs
+    except (ImportError, AttributeError):
+        best_configs = None
+    if best_configs is not None:
+        k = str((key[0], key[1], key[3]))
+        if k in best_configs:
+            logger.info(f"[Autotuner]: Loading configs for {k} from file.")
+            return True, best_configs[k][0], best_configs[k][1], None
+    logger.info(
+        f"[Autotuner]: Loading configs for {key} from file failed; Using default configs instead."
+    )
+    return False, 0, -1, None
+
+
 class AutoTuner:
     """AutoTuner for optimizing TensorRT-LLM operations.

@@ -316,11 +353,16 @@ def search_cache(
         [is_cache_hit, runner_id, tactic, stored_profile]
         """
         for r in runners:
+            cache_key = AutoTuner._get_cache_key(
+                custom_op, r, input_shapes, tuning_config
+            )
             if (
-                cache_key := AutoTuner._get_cache_key(
-                    custom_op, r, input_shapes, tuning_config
-                )
-            ) in self.profiling_cache:
+                os.environ.get("FLASHINFER_AUTOTUNER_LOAD_FROM_FILE", "0") == "1"
+                and not self.is_tuning_mode
+            ):
+                output = load_from_file(cache_key)
+                return output
+            elif cache_key in self.profiling_cache:
                 return True, *self.profiling_cache[cache_key]

         return False, 0, -1, None
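
A small sketch (my own addition, not code from this commit) for checking whether a tuning-config module exists for the local machine before enabling `FLASHINFER_AUTOTUNER_LOAD_FROM_FILE`:

import importlib.util

from flashinfer.autotuner import get_config_path

module_name = get_config_path(is_module=True)
try:
    available = importlib.util.find_spec(module_name) is not None
except ModuleNotFoundError:
    # flashinfer.tuning_configs itself may be missing on a fresh install
    available = False

if available:
    print(f"File-based configs found: {module_name}")
else:
    print(f"No configs for this GPU/version ({module_name}); default tactics will be used.")

When the module or the key is missing, `load_from_file` logs the miss and returns `(False, 0, -1, None)`, so the op falls back to the default tactic instead of failing.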

flashinfer/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -765,7 +765,7 @@ def cutlass_fused_moe(
         use_w4a8_group_scaling=use_w4a8_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
         min_latency_mode=min_latency_mode,
-        tune_max_num_tokens=8192,
+        tune_max_num_tokens=tune_max_num_tokens,
     )

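Usage note on the passthrough above: the caller's `tune_max_num_tokens` now reaches the underlying op instead of a fixed 8192. A minimal sketch mirroring the benchmark (tensor construction elided; all inputs are assumed to be prepared exactly as in `bench_cutlass_fused_moe.py`):

import torch

import flashinfer.fused_moe as fused_moe
from flashinfer.autotuner import autotune

# hidden_states, selected_experts, routing_weights, w1_q, w2_q, otype,
# quant_scales, input_sf, flash_output: prepared as in the benchmark script.
with torch.inference_mode(), autotune(True):
    _ = fused_moe.cutlass_fused_moe(
        hidden_states,
        selected_experts.to(torch.int),
        routing_weights,
        w1_q.contiguous().view(torch.long),
        w2_q.contiguous().view(torch.long),
        otype,
        quant_scales=quant_scales,
        input_sf=input_sf,
        output=flash_output,
        tune_max_num_tokens=16384,  # forwarded to the op instead of the fixed 8192
    )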