Commit 1649e23

bugfix: collect all modules to aot (#1622)
<!-- .github/pull_request_template.md -->

## 📌 Description

Fix #1556

## 🔍 Related Issues

#1556 <!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent dae1a0f commit 1649e23

File tree

13 files changed (+143, -55 lines)


csrc/xqa/barriers.cuh

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@
 */

 #pragma once
+#include <cassert>
+
 #include "cuda_hint.cuh"
 #include "defines.h"
 #if !USE_CUSTOM_BARRIER

flashinfer/aot.py

Lines changed: 112 additions & 32 deletions
@@ -9,6 +9,7 @@
 import torch.version

 from .activation import act_func_def_str, gen_act_and_mul_module
+from .fp8_quantization import gen_mxfp8_quantization_sm100_module
 from .cascade import gen_cascade_module
 from .fp4_quantization import (
     gen_fp4_quantization_sm100_module,
@@ -17,27 +18,36 @@
 from .fused_moe import (
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
+    gen_trtllm_gen_fused_moe_sm100_module,
+)
+from .gemm import (
+    gen_gemm_module,
+    gen_gemm_sm90_module,
+    gen_gemm_sm100_module,
+    gen_gemm_sm100_module_cutlass_fp4,
+    gen_gemm_sm100_module_cutlass_fp8,
+    gen_trtllm_gen_gemm_module,
 )
-from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
 from .jit import JitSpec, build_jit_specs
 from .jit import env as jit_env
 from .jit import (
     gen_batch_decode_module,
     gen_batch_mla_module,
     gen_batch_prefill_module,
     gen_fmha_cutlass_sm100a_module,
-    gen_jit_spec,
     gen_single_decode_module,
     gen_single_prefill_module,
+    gen_trtllm_gen_fmha_module,
 )
 from .mla import gen_mla_module
 from .norm import gen_norm_module
 from .page import gen_page_module
 from .quantization import gen_quantization_module
 from .rope import gen_rope_module
 from .sampling import gen_sampling_module
-from .tllm_utils import get_trtllm_utils_spec
-from .utils import version_at_least
+from .tllm_utils import gen_trtllm_utils_module
+from .utils import gen_logging_module, version_at_least
+from .xqa import gen_xqa_module
 from .compilation_context import CompilationContext


@@ -275,6 +285,9 @@ def gen_attention(
         use_logits_soft_cap=False,
     )

+    # trtllm_gen_fmha
+    yield gen_trtllm_gen_fmha_module()
+
     # MLA
     # NOTE: fp8 kv not supported in MLA
     mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else [])
@@ -296,6 +309,46 @@ def gen_attention(
     yield gen_mla_module()


+def gen_xqa(
+    use_fp16_: List[bool],
+    token_per_page_: List[int],
+    head_size_: List[int],
+    head_grp_size_: List[int],
+    use_sliding_window_: List[bool],
+    has_sm90: bool,
+) -> Iterator[JitSpec]:
+    """Generate XQA modules for various configurations."""
+    if not has_sm90:
+        return  # XQA requires SM90+
+
+    for (
+        use_fp16,
+        token_per_page,
+        head_size,
+        head_grp_size,
+        use_sliding_window,
+    ) in product(
+        use_fp16_,
+        token_per_page_,
+        head_size_,
+        head_grp_size_,
+        use_sliding_window_,
+    ):
+        # Skip invalid configurations
+        if head_size % 16 != 0 or head_size > 256 or head_size < 16:
+            continue
+        if token_per_page not in [16, 32, 64, 128]:
+            continue
+
+        yield gen_xqa_module(
+            use_fp16=use_fp16,
+            token_per_page=token_per_page,
+            head_size=head_size,
+            head_grp_size=head_grp_size,
+            use_sliding_window=use_sliding_window,
+        )
+
+
 def gen_all_modules(
     f16_dtype_: List[torch.dtype],
     f8_dtype_: List[torch.dtype],
@@ -311,6 +364,7 @@ def gen_all_modules(
     add_moe: bool,
     add_act: bool,
     add_misc: bool,
+    add_xqa: bool,
 ) -> List[JitSpec]:
     jit_specs: List[JitSpec] = []

@@ -343,14 +397,23 @@ def gen_all_modules(
         jit_specs.append(gen_fp4_quantization_sm100_module())
         jit_specs.append(gen_cutlass_fused_moe_sm100_module())
         jit_specs.append(gen_gemm_sm100_module())
+        jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
+        jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
+        jit_specs.append(gen_mxfp8_quantization_sm100_module())
+        jit_specs.append(gen_trtllm_gen_gemm_module())
+        jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())

     if add_comm:
         from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
         from .comm.nvshmem import gen_nvshmem_module
+        from .comm.trtllm_alltoall import gen_comm_alltoall_module
+        from .comm.trtllm_mnnvl_ar import gen_trtllm_mnnvl_comm_module

         jit_specs.append(gen_nvshmem_module())
+        jit_specs.append(gen_comm_alltoall_module())
         if has_sm100:
             jit_specs.append(gen_trtllm_comm_module())
+            jit_specs.append(gen_trtllm_mnnvl_comm_module())
         jit_specs.append(gen_vllm_comm_module())

     if add_misc:
@@ -363,7 +426,25 @@ def gen_all_modules(
             gen_sampling_module(),
         ]
         if has_sm90:
-            jit_specs.append(get_trtllm_utils_spec())
+            jit_specs.append(gen_trtllm_utils_module())
+
+    if add_xqa:
+        # Define XQA configurations to iterate over
+        xqa_use_fp16_ = [True, False]  # fp16 and bf16
+        xqa_token_per_page_ = [16, 32, 64, 128]
+        xqa_head_size_ = [64, 128, 256]
+        xqa_head_grp_size_ = [1, 2, 4, 8]  # Different group sizes for MQA/GQA
+
+        jit_specs += list(
+            gen_xqa(
+                xqa_use_fp16_,
+                xqa_token_per_page_,
+                xqa_head_size_,
+                xqa_head_grp_size_,
+                use_sliding_window_,
+                has_sm90,
+            )
+        )

     # dedup
     names = set()
@@ -479,6 +560,11 @@ def main():
         type=parse_bool,
         help="Add miscellaneous kernels",
     )
+    parser.add_argument(
+        "--add-xqa",
+        type=parse_bool,
+        help="Add XQA (Cross-Query Attention) kernels",
+    )
     args = parser.parse_args()

     # Default values
@@ -488,13 +574,13 @@ def main():
     fa2_head_dim_ = [
         (64, 64),
         (128, 128),
-        # (256, 256),
+        (256, 256),
     ]
     fa3_head_dim_ = [
         (192, 128),
         (128, 128),
-        # (64, 64),
-        # (256, 256),
+        (64, 64),
+        (256, 256),
     ]
     f16_dtype_ = [
         torch.float16,
@@ -506,18 +592,19 @@ def main():
     ]
     use_sliding_window_ = [
         False,
-        # True,
+        True,
     ]
     use_logits_soft_cap_ = [
         False,
-        # True,
+        True,
     ]
-    add_comm = False
-    add_gemma = False
+    add_comm = True
+    add_gemma = True
     add_oai_oss = True
-    add_moe = False
-    add_act = False
+    add_moe = True
+    add_act = True
     add_misc = True
+    add_xqa = True

     # Override
     if args.out_dir:
@@ -537,17 +624,19 @@ def main():
     if args.use_logits_soft_cap:
         use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
     if args.add_comm is not None:
-        add_comm = bool(args.add_comm)
+        add_comm = args.add_comm
     if args.add_gemma is not None:
-        add_gemma = bool(args.add_gemma)
+        add_gemma = args.add_gemma
     if args.add_oai_oss is not None:
-        add_oai_oss = bool(args.add_oai_oss)
+        add_oai_oss = args.add_oai_oss
     if args.add_moe is not None:
-        add_moe = bool(args.add_moe)
+        add_moe = args.add_moe
     if args.add_act is not None:
-        add_act = bool(args.add_act)
+        add_act = args.add_act
     if args.add_misc is not None:
-        add_misc = bool(args.add_misc)
+        add_misc = args.add_misc
+    if args.add_xqa is not None:
+        add_xqa = args.add_xqa

     # Cuda Arch
     if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
@@ -603,21 +692,11 @@ def has_sm(compute: str, version: str) -> bool:
     print(" add_moe:", add_moe)
     print(" add_act:", add_act)
     print(" add_misc:", add_misc)
+    print(" add_xqa:", add_xqa)

     # Generate JIT specs
     print("Generating JIT specs...")
-    jit_specs = [
-        gen_jit_spec(
-            "logging",
-            [
-                jit_env.FLASHINFER_CSRC_DIR / "logging.cc",
-            ],
-            extra_include_paths=[
-                jit_env.SPDLOG_INCLUDE_DIR,
-                jit_env.FLASHINFER_INCLUDE_DIR,
-            ],
-        )
-    ]
+    jit_specs = [gen_logging_module()]
     jit_specs += gen_all_modules(
         f16_dtype_,
         f8_dtype_,
@@ -633,6 +712,7 @@ def has_sm(compute: str, version: str) -> bool:
         add_moe,
         add_act,
         add_misc,
+        add_xqa,
     )
     print("Total ops:", len(jit_specs))

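For reference, the new `gen_xqa` generator above only enumerates `JitSpec` objects; building one of them outside the AOT script uses the same API. Below is a minimal sketch (not part of the commit), assuming an SM90+ GPU and a CUDA toolchain are available; the parameter values are illustrative:

```python
# Sketch: build a single XQA module with the same constraints gen_xqa()
# enforces (head_size a multiple of 16 in [16, 256], token_per_page in
# {16, 32, 64, 128}).
from flashinfer.xqa import gen_xqa_module

spec = gen_xqa_module(
    use_fp16=True,            # fp16 path; False selects bf16
    token_per_page=64,        # paged-KV page size
    head_size=128,            # per-head dimension
    head_grp_size=8,          # query heads per KV head (MQA/GQA grouping)
    use_sliding_window=False,
)
op = spec.build_and_load()    # JIT-compile now, or reuse the AOT-built artifact
```

The AOT script simply pre-builds every spec yielded by `gen_all_modules`, which is what the new `--add-xqa` flag toggles.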
flashinfer/decode.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@
     get_batch_prefill_uri,
     get_single_decode_uri,
     setup_cubin_loader,
-    trtllm_gen_fmha_module,
+    gen_trtllm_gen_fmha_module,
 )
 from .page import get_seq_lens
 from .prefill import (
@@ -302,7 +302,7 @@ def _fake_run_batch_decode(

 @functools.cache
 def get_trtllm_gen_fmha_module():
-    mod = trtllm_gen_fmha_module()
+    mod = gen_trtllm_gen_fmha_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
     return op
@@ -1810,7 +1810,7 @@ def run(
 class TrtllmGenDecodeModule:
     def __init__(self) -> None:
         self._sm_count: Optional[int] = None
-        self._mod = trtllm_gen_fmha_module()
+        self._mod = gen_trtllm_gen_fmha_module()
         self._op = self._mod.build_and_load()
         from flashinfer.jit.cubin_loader import setup_cubin_loader

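The rename from `trtllm_gen_fmha_module` to `gen_trtllm_gen_fmha_module` does not change how callers consume the module; the cached loader pattern stays exactly as shown in the diff. As a standalone sketch (mirroring the code above, and assuming both names are re-exported from `flashinfer.jit`, as the import hunks indicate):

```python
import functools

from flashinfer.jit import gen_trtllm_gen_fmha_module, setup_cubin_loader


@functools.cache
def get_trtllm_gen_fmha_module():
    # Generate the JitSpec, then build (or load the prebuilt) extension once
    # per process, and wire up the cubin loader for the resulting library.
    mod = gen_trtllm_gen_fmha_module()
    op = mod.build_and_load()
    setup_cubin_loader(mod.get_library_path())
    return op
```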
flashinfer/fused_moe/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     cutlass_fused_moe,
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
+    gen_trtllm_gen_fused_moe_sm100_module,
     reorder_rows_for_gated_act_gemm,
     trtllm_fp4_block_scale_moe,
     trtllm_fp4_block_scale_routed_moe,

flashinfer/fused_moe/core.py

Lines changed: 2 additions & 2 deletions
@@ -919,7 +919,7 @@ def cutlass_fused_moe(
 # trtllmgen-moe-fp8


-def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
+def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
     # Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
     # contains the `tllmGenBatchedGemmList` as the list of available kernels
     # online. It is included when compiling `trtllm_fused_moe_runner.cu`, etc.
@@ -975,7 +975,7 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec:

 @functools.cache
 def get_trtllm_moe_sm100_module():
-    module = trtllm_gen_fused_moe_sm100_module()
+    module = gen_trtllm_gen_fused_moe_sm100_module()
     moe_op = module.build_and_load()
     setup_cubin_loader(str(module.get_library_path()))

flashinfer/gemm.py

Lines changed: 3 additions & 3 deletions
@@ -412,7 +412,7 @@ def get_gemm_sm100_module():
     return module


-def trtllm_gemm_gen_module() -> JitSpec:
+def gen_trtllm_gen_gemm_module() -> JitSpec:
     # Fetch "flashinferMetaInfo.h" from the online kernel cache. This file
     # contains the `tllmGenGemmList` as the list of available kernels online.
     # It is included when compiling `trtllm_gemm_runner.cu`.
@@ -446,7 +446,7 @@ def trtllm_gemm_gen_module() -> JitSpec:

 @functools.cache
 def get_trtllm_gemm_module():
-    mod = trtllm_gemm_gen_module()
+    mod = gen_trtllm_gen_gemm_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
     return op
@@ -2019,7 +2019,7 @@ def gemm_fp8_nt_groupwise(

 @functools.cache
 def get_trtllm_fp4_gemm_module():
-    mod = trtllm_gemm_gen_module()
+    mod = gen_trtllm_gen_gemm_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())

flashinfer/jit/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from . import env as env
 from .activation import gen_act_and_mul_module as gen_act_and_mul_module
 from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
-from .attention import cudnn_fmha_gen_module as cudnn_fmha_gen_module
+from .attention import gen_cudnn_fmha_module as gen_cudnn_fmha_module
 from .attention import gen_batch_attention_module as gen_batch_attention_module
 from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
 from .attention import gen_batch_decode_module as gen_batch_decode_module
@@ -61,7 +61,7 @@
 from .attention import get_pod_uri as get_pod_uri
 from .attention import get_single_decode_uri as get_single_decode_uri
 from .attention import get_single_prefill_uri as get_single_prefill_uri
-from .attention import trtllm_gen_fmha_module as trtllm_gen_fmha_module
+from .attention import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module
 from .core import JitSpec as JitSpec
 from .core import build_jit_specs as build_jit_specs
 from .core import clear_cache_dir as clear_cache_dir
@@ -78,7 +78,7 @@

 @functools.cache
 def get_cudnn_fmha_gen_module():
-    mod = cudnn_fmha_gen_module()
+    mod = gen_cudnn_fmha_module()
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
     return op

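With every generator renamed to the `gen_*_module` convention and collected in `flashinfer/aot.py`, the AOT prebuild now covers these modules as well. A hedged sketch of driving the build from Python follows; the module-style invocation and the `true` flag value are assumptions, since only the `--add-xqa` argparse option itself appears in this diff:

```python
import subprocess
import sys

# Assumption: flashinfer/aot.py is runnable as a module and parse_bool accepts
# "true"/"false". Other --add-* flags follow the same pattern as --add-xqa.
subprocess.run(
    [sys.executable, "-m", "flashinfer.aot", "--add-xqa", "true"],
    check=True,
)
```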