
Commit 49d744d

Add flags to trim down AoT builds (#1393)
1 parent a6a1e49 commit 49d744d


flashinfer/aot.py

Lines changed: 131 additions & 28 deletions
@@ -6,11 +6,11 @@
 from typing import List, Tuple
 
 import torch
+import torch.version
 from torch.utils.cpp_extension import _get_cuda_arch_flags
 
 from .activation import act_func_def_str, gen_act_and_mul_module
 from .cascade import gen_cascade_module
-from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
 from .fp4_quantization import gen_fp4_quantization_sm100_module
 from .fused_moe import gen_cutlass_fused_moe_sm100_module
 from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -42,11 +42,15 @@ def gen_fa2(
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_attention_sink: bool,
 ) -> List[JitSpec]:
     if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
         return []
     if dtype_qo.itemsize == 1:
         return []  # fp8 tensor cores not supported in fa2
+
+    # TODO: support for AoT sink attention.
+
     return [
         gen_single_prefill_module(
             backend="fa2",
@@ -105,6 +109,7 @@ def gen_fa3(
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_attention_sink: bool,
 ) -> List[JitSpec]:
     if dtype_q != dtype_kv:
         return []  # fa3 template do not support mixed precision
@@ -116,6 +121,8 @@
         if head_dim_qk == 192 or head_dim_qk == 64:
             return []  # (192, 128) & (64, 64) not supported for fp8 yet.
 
+    # TODO: support for AoT sink attention.
+
     return [
         gen_single_prefill_module(
             backend="fa3",
@@ -155,6 +162,7 @@ def gen_attention(
     has_sm90: bool,
     has_sm100: bool,
     add_gemma: bool,
+    add_oai_oss: bool,
 ) -> List[JitSpec]:
     head_dim_ckv = 512
     head_dim_kpe = 64
@@ -181,6 +189,7 @@ def gen_attention(
             head_dim_vo=head_dim_vo,
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
+            use_attention_sink=False,
         )
 
     # FA3 MHA / MQA / GQA
@@ -206,6 +215,7 @@ def gen_attention(
                 head_dim_vo=head_dim_vo,
                 use_sliding_window=use_sliding_window,
                 use_logits_soft_cap=use_logits_soft_cap,
+                use_attention_sink=False,
             )
 
     # Gemma
@@ -226,6 +236,7 @@ def gen_attention(
                 head_dim_vo=256,
                 use_sliding_window=use_sliding_window,
                 use_logits_soft_cap=use_logits_soft_cap,
+                use_attention_sink=False,
             )
         if has_sm90:
             for (
@@ -245,8 +256,30 @@ def gen_attention(
                     head_dim_vo=256,
                     use_sliding_window=use_sliding_window,
                     use_logits_soft_cap=use_logits_soft_cap,
+                    use_attention_sink=False,
                 )
 
+    # OAI OSS
+    if add_oai_oss:
+        for (
+            dtype_qo,
+            dtype_kv,
+            use_sliding_window,
+        ) in product(
+            f16_dtype_,
+            f16_dtype_ + f8_dtype_,
+            [True],
+        ):
+            jit_specs += gen_fa2(
+                dtype_qo=dtype_qo,
+                dtype_kv=dtype_kv,
+                head_dim_qk=64,
+                head_dim_vo=64,
+                use_sliding_window=use_sliding_window,
+                use_logits_soft_cap=False,
+                use_attention_sink=True,
+            )
+
     # fmha_cutlass_sm100a
     # NOTE: currently there's only one uri.
     if has_sm100:
@@ -301,7 +334,12 @@ def gen_all_modules(
     use_logits_soft_cap_: List[bool],
     has_sm90: bool,
     has_sm100: bool,
+    add_comm: bool,
     add_gemma: bool,
+    add_oai_oss: bool,
+    add_moe: bool,
+    add_act: bool,
+    add_misc: bool,
 ) -> List[JitSpec]:
     jit_specs: List[JitSpec] = []
 
@@ -315,27 +353,40 @@ def gen_all_modules(
         has_sm90,
         has_sm100,
         add_gemma,
+        add_oai_oss,
     )
-    for act_name in act_func_def_str:
-        jit_specs.append(gen_act_and_mul_module(act_name))
-    jit_specs.append(gen_gemm_module())
+
+    if add_act:
+        for act_name in act_func_def_str:
+            jit_specs.append(gen_act_and_mul_module(act_name))
+
+    if add_moe:
+        jit_specs.append(gen_gemm_module())
+        if has_sm90:
+            jit_specs.append(gen_gemm_sm90_module())
+        if has_sm100:
+            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
+            jit_specs.append(gen_fp4_quantization_sm100_module())
+            jit_specs.append(gen_gemm_sm100_module())
+
+    if add_comm:
+        from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
+
+        if has_sm100:
+            jit_specs.append(gen_trtllm_comm_module())
+        jit_specs.append(gen_vllm_comm_module())
+
+    if add_misc:
+        jit_specs += [
+            gen_cascade_module(),
+            gen_norm_module(),
+            gen_page_module(),
+            gen_quantization_module(),
+            gen_rope_module(),
+            gen_sampling_module(),
+        ]
     if has_sm90:
-        jit_specs.append(gen_gemm_sm90_module())
-    if has_sm100:
-        jit_specs.append(gen_cutlass_fused_moe_sm100_module())
-        jit_specs.append(gen_fp4_quantization_sm100_module())
-        jit_specs.append(gen_gemm_sm100_module())
-        jit_specs.append(gen_trtllm_comm_module())
-
-    jit_specs += [
-        gen_cascade_module(),
-        gen_vllm_comm_module(),
-        gen_norm_module(),
-        gen_page_module(),
-        gen_quantization_module(),
-        gen_rope_module(),
-        gen_sampling_module(),
-    ]
+        jit_specs.append(get_trtllm_utils_spec())
 
     # dedup
     names = set()
421472
nargs="*",
422473
help="Use logits soft cap",
423474
)
475+
parser.add_argument(
476+
"--add-comm",
477+
type=parse_bool,
478+
help="Add communication kernels (trtllm_comm, vllm_comm)",
479+
)
424480
parser.add_argument(
425481
"--add-gemma",
426482
type=parse_bool,
427483
help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)",
428484
)
485+
parser.add_argument(
486+
"--add-oai-oss",
487+
type=parse_bool,
488+
help="Add kernels for OAI OSS Model (head_dim=64, use_sliding_window)",
489+
)
490+
parser.add_argument(
491+
"--add-moe",
492+
type=parse_bool,
493+
help="Add MoE kernels",
494+
)
495+
parser.add_argument(
496+
"--add-act",
497+
type=parse_bool,
498+
help="Add activation kernels",
499+
)
500+
parser.add_argument(
501+
"--add-misc",
502+
type=parse_bool,
503+
help="Add miscellaneous kernels",
504+
)
429505
args = parser.parse_args()
430506

431507
# Default values
@@ -459,7 +535,12 @@ def main():
459535
False,
460536
# True,
461537
]
538+
add_comm = False
462539
add_gemma = True
540+
add_oai_oss = True
541+
add_moe = False
542+
add_act = True
543+
add_misc = True
463544

464545
# Override
465546
if args.out_dir:
@@ -478,19 +559,33 @@ def main():
478559
use_sliding_window_ = [parse_bool(s) for s in args.use_sliding_window]
479560
if args.use_logits_soft_cap:
480561
use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
562+
if args.add_comm is not None:
563+
add_comm = bool(args.add_comm)
481564
if args.add_gemma is not None:
482565
add_gemma = bool(args.add_gemma)
566+
if args.add_oai_oss is not None:
567+
add_oai_oss = bool(args.add_oai_oss)
568+
if args.add_moe is not None:
569+
add_moe = bool(args.add_moe)
570+
if args.add_act is not None:
571+
add_act = bool(args.add_act)
572+
if args.add_misc is not None:
573+
add_misc = bool(args.add_misc)
483574

484575
# Cuda Arch
485576
if "TORCH_CUDA_ARCH_LIST" not in os.environ:
486577
raise RuntimeError("Please explicitly set env var TORCH_CUDA_ARCH_LIST.")
487578
gencode_flags = _get_cuda_arch_flags()
488-
has_sm90 = any("compute_90" in flag for flag in gencode_flags) and version_at_least(
489-
torch.version.cuda, "12.3"
490-
)
491-
has_sm100 = any(
492-
"compute_100" in flag for flag in gencode_flags
493-
) and version_at_least(torch.version.cuda, "12.8")
579+
580+
def has_sm(compute: str, version: str) -> bool:
581+
if not any("compute_90" in flag for flag in gencode_flags):
582+
return False
583+
if torch.version.cuda is None:
584+
return True
585+
return version_at_least(torch.version.cuda, version)
586+
587+
has_sm90 = has_sm("compute_90", "12.3")
588+
has_sm100 = has_sm("compute_100", "12.8")
494589

495590
# Update data dir
496591
jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
@@ -521,7 +616,12 @@ def main():
     print(" TORCH_CUDA_ARCH_LIST:", os.environ["TORCH_CUDA_ARCH_LIST"])
     print(" has_sm90:", has_sm90)
     print(" has_sm100:", has_sm100)
+    print(" add_comm:", add_comm)
     print(" add_gemma:", add_gemma)
+    print(" add_oai_oss:", add_oai_oss)
+    print(" add_moe:", add_moe)
+    print(" add_act:", add_act)
+    print(" add_misc:", add_misc)
 
     # Generate JIT specs
     print("Generating JIT specs...")
@@ -537,8 +637,6 @@ def main():
             ],
         )
     ]
-    if has_sm90:
-        jit_specs.append(get_trtllm_utils_spec())
     jit_specs += gen_all_modules(
         f16_dtype_,
         f8_dtype_,
@@ -548,7 +646,12 @@ def main():
         use_logits_soft_cap_,
         has_sm90,
         has_sm100,
+        add_comm,
         add_gemma,
+        add_oai_oss,
+        add_moe,
+        add_act,
+        add_misc,
     )
     print("Total ops:", len(jit_specs))

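Usage note: a minimal sketch of how a trimmed AoT build might be invoked with the new flags. The `python -m flashinfer.aot` entry point, the TORCH_CUDA_ARCH_LIST value, and the "true"/"false" spellings accepted by parse_bool are assumptions, not taken from this commit; only the flag names and the requirement that TORCH_CUDA_ARCH_LIST be set explicitly come from the diff above.

# Hypothetical invocation sketch; module path, arch-list value, and boolean
# spellings are assumptions rather than something this commit specifies.
import os
import subprocess

env = dict(os.environ)
env["TORCH_CUDA_ARCH_LIST"] = "9.0a"  # required: aot.py raises RuntimeError if unset

subprocess.run(
    [
        "python", "-m", "flashinfer.aot",
        "--add-comm", "false",    # skip trtllm_comm / vllm_comm kernels
        "--add-moe", "false",     # skip MoE and SM100 GEMM kernels
        "--add-gemma", "false",   # skip Gemma kernels (head_dim=256)
        "--add-oai-oss", "true",  # keep OAI OSS kernels (head_dim=64, sliding window)
        "--add-act", "true",      # keep activation kernels
        "--add-misc", "true",     # keep cascade/norm/page/quantization/rope/sampling
    ],
    env=env,
    check=True,
)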