Commit 11ded11

[#8389][fix] Update group attention matching to first map to custom torch attention (#8638)
Signed-off-by: Fridah-nv <[email protected]>
1 parent 70e4d72 commit 11ded11

File tree

4 files changed, +172 -96 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 2 deletions

@@ -34,9 +34,9 @@ transforms:
   match_eager_attention:
     stage: pattern_matcher
     requires_shape_prop: true
-  match_grouped_attention_with_repeat_kv:
+  match_sdpa_to_torch_attention:
     stage: pattern_matcher
-  match_grouped_attention_without_repeat_kv:
+  match_grouped_attention:
    stage: pattern_matcher
   match_attention_layout:
     stage: pattern_matcher
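For context, the two renamed stages target the eager grouped-query attention (GQA) pattern that exported models typically contain: key/value heads are expanded with a repeat_kv helper and then fed to torch.nn.functional.scaled_dot_product_attention. The first stage normalizes the SDPA call to the custom torch_attention op; the second then recognizes and drops the now-redundant repeat_kv. A minimal sketch of that eager pattern (illustrative only, not the repository's code; repeat_kv follows the common Hugging Face-style implementation):

import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim).
    b, n_kv, s, d = x.shape
    if n_rep == 1:
        return x
    return x[:, :, None, :, :].expand(b, n_kv, n_rep, s, d).reshape(b, n_kv * n_rep, s, d)


def eager_gqa_attention(q, k, v, n_rep: int):
    # k/v carry fewer heads than q; expand them, then call SDPA.
    # After the two transforms below, this subgraph collapses into a single
    # torch.ops.auto_deploy.torch_attention call with no repeat_kv left in the graph.
    k = repeat_kv(k, n_rep)
    v = repeat_kv(v, n_rep)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)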

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

Lines changed: 166 additions & 90 deletions

@@ -11,7 +11,6 @@
 
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils.logger import ad_logger
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import (
     BaseTransform,
@@ -374,28 +373,137 @@ def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scal
     return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
 
 
-def make_grouped_attn_pair(
+def make_sdpa_to_torch_attn_pair(
     *,
-    repeat_kv: bool,
     is_causal: bool,
     has_scale: bool,
     enable_gqa: bool,
     has_attn_mask: bool,
     has_dropout: bool,
 ) -> Tuple[Callable, Callable, List[str]]:
     """
-    Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.
-
-    Arg order rules:
-      Base: (q, k, v)
-      +repeat_kv -> insert n_rep after (q, k, v)
-      +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
-      +dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
-      +scale -> include scale last
+    Returns (pattern_fn, replacement_fn, argnames) for matching SDPA to torch_attention.
+
+    Pattern: torch_attention_sdpa --> torch_attention
     """
     argnames: List[str] = ["q", "k", "v"]
-    if repeat_kv:
-        argnames.append("n_rep")
+    if has_attn_mask:
+        argnames.append("attn_mask")
+    if has_dropout:
+        argnames.append("dropout_p")
+    if has_scale:
+        argnames.append("scale")
+
+    def pattern_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_sdpa(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            enable_gqa=enable_gqa,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    def replacement_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_attn(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    _attach_signature(pattern_fn, argnames)
+    _attach_signature(replacement_fn, argnames)
+    return pattern_fn, replacement_fn, argnames
+
+
+def generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern: Callable):
+    """
+    Generate patterns for matching SDPA to torch_attention.
+    Enumerates combinations across:
+      - is_causal: [False, True]
+      - has_scale: [False, True]
+      - enable_gqa: [False, True]
+      - has_attn_mask: [False, True]
+      - has_dropout: [False, True]
+    """
+    q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    k = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    v = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
+
+    dropout_val = 0.12345
+    scale_val = 0.56789
+
+    total = 0
+    axes = ((False, True),) * 5
+    for is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_sdpa_to_torch_attn_pair(
+            is_causal=is_causal,
+            has_scale=has_scale,
+            enable_gqa=enable_gqa,
+            has_attn_mask=has_attn_mask,
+            has_dropout=has_dropout,
+        )
+
+        value_map = {
+            "q": q,
+            "k": k,
+            "v": v,
+            "attn_mask": attn_mask_tensor,
+            "dropout_p": dropout_val,
+            "scale": scale_val,
+        }
+        dummy_args: List[object] = []
+        for name in argnames:
+            try:
+                dummy_args.append(value_map[name])
+            except KeyError:
+                raise RuntimeError(f"Unexpected arg name: {name}")
+
+        scalar_names = {"dropout_p", "scale"}
+        scalar_workaround: Dict[str, object] = {
+            n: value_map[n] for n in argnames if n in scalar_names
+        }
+        if not scalar_workaround:
+            scalar_workaround = None
+
+        register_ad_pattern(
+            search_fn=pat_fn,
+            replace_fn=rep_fn,
+            patterns=patterns,
+            dummy_args=dummy_args,
+            scalar_workaround=scalar_workaround,
+        )
+        total += 1
+    return total
+
+
+def make_repeat_kv_torch_attn_pair(
+    *,
+    is_causal: bool,
+    has_scale: bool,
+    has_attn_mask: bool,
+    has_dropout: bool,
+) -> Tuple[Callable, Callable, List[str]]:
+    """
+    Returns (pattern_fn, replacement_fn, argnames) for matching repeat_kv + torch_attention.
+
+    Pattern: repeat_kv(k, n_rep), repeat_kv(v, n_rep), torch_attention --> torch_attention
+    This handles GQA patterns where repeat_kv is explicitly applied before torch_attention.
+    """
+    argnames: List[str] = ["q", "k", "v", "n_rep"]
     if has_attn_mask:
         argnames.append("attn_mask")
     if has_dropout:
@@ -411,30 +519,27 @@ def pattern_fn(*args):
         q = m["q"]
         k = m["k"]
         v = m["v"]
+        n_rep = m["n_rep"]
 
-        if repeat_kv:
-            n_rep = m["n_rep"]
-            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
-            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
+        # Apply repeat_kv to k and v
+        k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+        v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
 
-        return _call_sdpa(
+        return _call_attn(
             q,
             k,
             v,
             is_causal=is_causal,
-            enable_gqa=enable_gqa,
             attn_mask=m.get("attn_mask"),
             dropout_p=m.get("dropout_p"),
             scale=m.get("scale"),
         )
 
-    # Replacement: torch_attention.default mirroring the positional signature exactly.
-    # We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature,
-    # but we don’t need to use it in the replacement graph.
     def replacement_fn(*args):
         if len(args) != len(argnames):
             raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
         m = dict(zip(argnames, args))
+        # Replacement: just call torch_attention directly (no repeat_kv needed)
         return _call_attn(
             m["q"],
             m["k"],
@@ -445,37 +550,19 @@ def replacement_fn(*args):
             scale=m.get("scale"),
         )
 
-    # Pattern matcher needs to see explicit arg names
     _attach_signature(pattern_fn, argnames)
     _attach_signature(replacement_fn, argnames)
-
     return pattern_fn, replacement_fn, argnames
 
 
-def generate_and_register_grouped_attn_patterns(
-    patterns, register_ad_pattern: Callable, only_repeat_kv: bool = None
-):
+def generate_and_register_repeat_kv_torch_attn_patterns(patterns, register_ad_pattern: Callable):
     """
-    Auto-generate all grouped attention patterns across these axes:
-      1) repeat_kv: [False, True]
-      2) is_causal: [False, True]
-      3) has_scale: [False, True]
-      4) enable_gqa: [False, True] (only a kwarg to SDPA side)
-      5) has_attn_mask: [False, True]
-      6) has_dropout: [False, True]
-
-    Args:
-        patterns: The ADPatternMatcherPass instance to register patterns to
-        register_ad_pattern: The function to call to register each pattern
-        only_repeat_kv: If True, only register patterns with repeat_kv=True.
-            If False, only register patterns with repeat_kv=False.
-            If None, register all patterns.
-
-    For each valid combo, we:
-      - build pattern/replacement functions with exact-arg parity
-      - build dummy args matching the signature (with CUDA fp16 tensors etc.)
-      - build scalar_workaround dict for any scalars/n_rep present
-      - call register_ad_pattern(...)
+    Generate patterns for matching repeat_kv + torch_attention.
+    Enumerates combinations across:
+      - is_causal: [False, True]
+      - has_scale: [False, True]
+      - has_attn_mask: [False, True]
+      - has_dropout: [False, True]
     """
     q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
     k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
@@ -487,24 +574,15 @@ def generate_and_register_grouped_attn_patterns(
     n_rep_val = 7
 
     total = 0
-    axes = ((False, True),) * 6
-    for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
-        if only_repeat_kv is not None:
-            if only_repeat_kv and not repeat_kv:
-                continue  # Skip patterns without repeat_kv
-            if not only_repeat_kv and repeat_kv:
-                continue  # Skip patterns with repeat_kv
-
-        pat_fn, rep_fn, argnames = make_grouped_attn_pair(
-            repeat_kv=repeat_kv,
+    axes = ((False, True),) * 4
+    for is_causal, has_scale, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_repeat_kv_torch_attn_pair(
             is_causal=is_causal,
             has_scale=has_scale,
-            enable_gqa=enable_gqa,
             has_attn_mask=has_attn_mask,
             has_dropout=has_dropout,
         )
 
-        # Build dummy args in the same positional order
         value_map = {
             "q": q,
             "k": k1,
@@ -539,12 +617,17 @@ def generate_and_register_grouped_attn_patterns(
     return total
 
 
-@TransformRegistry.register("match_grouped_attention_with_repeat_kv")
-class MatchGroupedAttentionWithRepeatKV(BaseTransform):
+@TransformRegistry.register("match_sdpa_to_torch_attention")
+class MatchSDPAToTorchAttention(BaseTransform):
     """
-    Match and replace grouped attention patterns WITH repeat_kv to
-    torch.ops.auto_deploy.torch_attention.
+    Match and replace SDPA patterns to torch.ops.auto_deploy.torch_attention.
 
+    This handles:
+    - sdpa --> torch_attention
+    - repeat_kv + sdpa --> torch_attention
+
+    This transform should run BEFORE match_repeat_kv_with_torch_attention to ensure
+    SDPA calls are converted first.
     """
 
     def _apply(
@@ -554,32 +637,33 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        def register_grouped_attention_with_repeat_kv(patterns: ADPatternMatcherPass):
-            return generate_and_register_grouped_attn_patterns(
-                patterns, register_ad_pattern, only_repeat_kv=True
-            )
+        def register_sdpa_to_torch_attention(patterns: ADPatternMatcherPass):
+            return generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern)
 
-        num_grouped_patterns = _apply_pattern(
-            gm, "Grouped Attention (with repeat_kv)", register_grouped_attention_with_repeat_kv
+        num_patterns = _apply_pattern(
+            gm, "SDPA to Torch Attention", register_sdpa_to_torch_attention
         )
 
         info = TransformInfo(
             skipped=False,
-            num_matches=num_grouped_patterns,
+            num_matches=num_patterns,
             is_clean=False,
             has_valid_shapes=False,
         )
         return gm, info
 
 
-@TransformRegistry.register("match_grouped_attention_without_repeat_kv")
-class MatchGroupedAttentionWithoutRepeatKV(BaseTransform):
+@TransformRegistry.register("match_grouped_attention")
+class MatchRepeatKVWithTorchAttention(BaseTransform):
     """
-    Match and replace grouped attention patterns WITHOUT repeat_kv to
-    torch.ops.auto_deploy.torch_attention.
+    Match and replace repeat_kv + torch_attention patterns to torch_attention.
 
-    This transform should run AFTER match_grouped_attention_with_repeat_kv
-    to avoid incorrectly matching patterns that should have repeat_kv.
+    This handles:
+    - repeat_kv + torch_attention --> torch_attention (removes redundant repeat_kv)
+    - torch_attention --> torch_attention (identity, catches any remaining patterns)
+
+    This transform should run AFTER match_sdpa_to_torch_attention to ensure
+    we match the repeat_kv + torch_attention pattern correctly.
     """
 
     def _apply(
@@ -589,26 +673,18 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        def register_grouped_attention_without_repeat_kv(patterns: ADPatternMatcherPass):
-            return generate_and_register_grouped_attn_patterns(
-                patterns, register_ad_pattern, only_repeat_kv=False
+        def register_repeat_kv_with_torch_attention(patterns: ADPatternMatcherPass):
+            return generate_and_register_repeat_kv_torch_attn_patterns(
+                patterns, register_ad_pattern
             )
 
-        num_grouped_patterns = _apply_pattern(
-            gm,
-            "Grouped Attention (without repeat_kv)",
-            register_grouped_attention_without_repeat_kv,
+        num_patterns = _apply_pattern(
+            gm, "Repeat KV with Torch Attention", register_repeat_kv_with_torch_attention
         )
 
-        if num_grouped_patterns == 0:
-            ad_logger.warning(
-                "Fail to find any Group Attention Pattern (without repeat_kv), "
-                "output or performance may be incorrect"
-            )
-
         info = TransformInfo(
             skipped=False,
-            num_matches=num_grouped_patterns,
+            num_matches=num_patterns,
             is_clean=False,
             has_valid_shapes=False,
         )
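The pattern and replacement callables above are generated as bare *args functions, and the removed comment notes that the pattern matcher needs to see explicit arg names; _attach_signature presumably supplies them. Its implementation is not part of this diff, so the following is only a hedged sketch of one way such a helper could work, using inspect.Signature:

import inspect
from typing import Callable, List


def attach_signature(fn: Callable, argnames: List[str]) -> None:
    # Give a *args function an explicit positional signature so tools that
    # introspect it (such as a pattern matcher tracing dummy arguments)
    # see real parameter names instead of a bare *args.
    fn.__signature__ = inspect.Signature(
        [inspect.Parameter(n, inspect.Parameter.POSITIONAL_OR_KEYWORD) for n in argnames]
    )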

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py

Lines changed: 2 additions & 2 deletions

@@ -446,10 +446,10 @@ def _get_match_grouped_attention_optimizer() -> Callable:
         "cleanup_noop_slice": {
             "stage": "post_export",
         },
-        "match_grouped_attention_with_repeat_kv": {
+        "match_sdpa_to_torch_attention": {
             "stage": "pattern_matcher",
         },
-        "match_grouped_attention_without_repeat_kv": {
+        "match_grouped_attention": {
             "stage": "pattern_matcher",
         },
     }
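For reference, each stage registers one pattern variant per combination of its boolean axes: ((False, True),) * 5 yields 32 SDPA variants and ((False, True),) * 4 yields 16 repeat_kv variants. A tiny sketch of that enumeration, with register_fn as a hypothetical stand-in for register_ad_pattern:

from itertools import product


def count_variants(num_axes: int, register_fn=None) -> int:
    # One pattern per combination of boolean axes (2 ** num_axes in total).
    total = 0
    for combo in product((False, True), repeat=num_axes):
        if register_fn is not None:
            register_fn(combo)
        total += 1
    return total


assert count_variants(5) == 32  # match_sdpa_to_torch_attention
assert count_variants(4) == 16  # match_grouped_attention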
