 
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils.logger import ad_logger
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import (
     BaseTransform,
@@ -374,28 +373,137 @@ def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scal
     return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
 
 
-def make_grouped_attn_pair(
+def make_sdpa_to_torch_attn_pair(
     *,
-    repeat_kv: bool,
     is_causal: bool,
     has_scale: bool,
     enable_gqa: bool,
     has_attn_mask: bool,
     has_dropout: bool,
 ) -> Tuple[Callable, Callable, List[str]]:
     """
-    Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.
-
-    Arg order rules:
-      Base: (q, k, v)
-      +repeat_kv -> insert n_rep after (q, k, v)
-      +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
-      +dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
-      +scale -> include scale last
+    Returns (pattern_fn, replacement_fn, argnames) for matching SDPA to torch_attention.
+
+    Pattern: torch_attention_sdpa --> torch_attention
     """
     argnames: List[str] = ["q", "k", "v"]
-    if repeat_kv:
-        argnames.append("n_rep")
+    if has_attn_mask:
+        argnames.append("attn_mask")
+    if has_dropout:
+        argnames.append("dropout_p")
+    if has_scale:
+        argnames.append("scale")
+
+    def pattern_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_sdpa(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            enable_gqa=enable_gqa,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    def replacement_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_attn(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    _attach_signature(pattern_fn, argnames)
+    _attach_signature(replacement_fn, argnames)
+    return pattern_fn, replacement_fn, argnames
+
+
+def generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern: Callable):
+    """
+    Generate patterns for matching SDPA to torch_attention.
+    Enumerates combinations across:
+    - is_causal: [False, True]
+    - has_scale: [False, True]
+    - enable_gqa: [False, True]
+    - has_attn_mask: [False, True]
+    - has_dropout: [False, True]
+    """
+    q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    k = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    v = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
+
+    dropout_val = 0.12345
+    scale_val = 0.56789
+
+    total = 0
+    axes = ((False, True),) * 5
+    for is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_sdpa_to_torch_attn_pair(
+            is_causal=is_causal,
+            has_scale=has_scale,
+            enable_gqa=enable_gqa,
+            has_attn_mask=has_attn_mask,
+            has_dropout=has_dropout,
+        )
+
+        value_map = {
+            "q": q,
+            "k": k,
+            "v": v,
+            "attn_mask": attn_mask_tensor,
+            "dropout_p": dropout_val,
+            "scale": scale_val,
+        }
+        dummy_args: List[object] = []
+        for name in argnames:
+            try:
+                dummy_args.append(value_map[name])
+            except KeyError:
+                raise RuntimeError(f"Unexpected arg name: {name}")
+
+        scalar_names = {"dropout_p", "scale"}
+        scalar_workaround: Dict[str, object] = {
+            n: value_map[n] for n in argnames if n in scalar_names
+        }
+        if not scalar_workaround:
+            scalar_workaround = None
+
+        register_ad_pattern(
+            search_fn=pat_fn,
+            replace_fn=rep_fn,
+            patterns=patterns,
+            dummy_args=dummy_args,
+            scalar_workaround=scalar_workaround,
+        )
+        total += 1
+    return total
+
+
+def make_repeat_kv_torch_attn_pair(
+    *,
+    is_causal: bool,
+    has_scale: bool,
+    has_attn_mask: bool,
+    has_dropout: bool,
+) -> Tuple[Callable, Callable, List[str]]:
+    """
+    Returns (pattern_fn, replacement_fn, argnames) for matching repeat_kv + torch_attention.
+
+    Pattern: repeat_kv(k, n_rep), repeat_kv(v, n_rep), torch_attention --> torch_attention
+    This handles GQA patterns where repeat_kv is explicitly applied before torch_attention.
+    """
+    argnames: List[str] = ["q", "k", "v", "n_rep"]
     if has_attn_mask:
         argnames.append("attn_mask")
     if has_dropout:
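Note (illustrative sketch, not part of this diff): each of the 2^5 = 32 enumerated variants above corresponds to one concrete SDPA call shape. A minimal stand-in for the kind of call site the `torch_attention_sdpa --> torch_attention` rewrite targets, assuming `torch_attention_sdpa` mirrors `torch.nn.functional.scaled_dot_product_attention`:

```python
# Sketch only: plain-PyTorch stand-in for one enumerated combination
# (is_causal=False, has_scale=True, enable_gqa=False, has_attn_mask=True, has_dropout=False).
import torch
import torch.nn.functional as F


def toy_attention(q, k, v, mask):
    # The pattern side would see an equivalent torch_attention_sdpa node with these kwargs;
    # the replacement rewrites it to torch.ops.auto_deploy.torch_attention.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=0.125)


q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
mask = torch.zeros(1, 1, 16, 16)
print(toy_attention(q, k, v, mask).shape)  # torch.Size([1, 8, 16, 64])
```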
@@ -411,30 +519,27 @@ def pattern_fn(*args):
         q = m["q"]
         k = m["k"]
         v = m["v"]
+        n_rep = m["n_rep"]
 
-        if repeat_kv:
-            n_rep = m["n_rep"]
-            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
-            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
+        # Apply repeat_kv to k and v
+        k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+        v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
 
-        return _call_sdpa(
+        return _call_attn(
             q,
             k,
             v,
             is_causal=is_causal,
-            enable_gqa=enable_gqa,
             attn_mask=m.get("attn_mask"),
             dropout_p=m.get("dropout_p"),
             scale=m.get("scale"),
         )
 
-    # Replacement: torch_attention.default mirroring the positional signature exactly.
-    # We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature,
-    # but we don’t need to use it in the replacement graph.
     def replacement_fn(*args):
         if len(args) != len(argnames):
             raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
         m = dict(zip(argnames, args))
+        # Replacement: just call torch_attention directly (no repeat_kv needed)
         return _call_attn(
             m["q"],
             m["k"],
@@ -445,37 +550,19 @@ def replacement_fn(*args):
             scale=m.get("scale"),
         )
 
-    # Pattern matcher needs to see explicit arg names
     _attach_signature(pattern_fn, argnames)
     _attach_signature(replacement_fn, argnames)
-
     return pattern_fn, replacement_fn, argnames
 
 
-def generate_and_register_grouped_attn_patterns(
-    patterns, register_ad_pattern: Callable, only_repeat_kv: bool = None
-):
+def generate_and_register_repeat_kv_torch_attn_patterns(patterns, register_ad_pattern: Callable):
     """
-    Auto-generate all grouped attention patterns across these axes:
-    1) repeat_kv: [False, True]
-    2) is_causal: [False, True]
-    3) has_scale: [False, True]
-    4) enable_gqa: [False, True] (only a kwarg to SDPA side)
-    5) has_attn_mask: [False, True]
-    6) has_dropout: [False, True]
-
-    Args:
-        patterns: The ADPatternMatcherPass instance to register patterns to
-        register_ad_pattern: The function to call to register each pattern
-        only_repeat_kv: If True, only register patterns with repeat_kv=True.
-            If False, only register patterns with repeat_kv=False.
-            If None, register all patterns.
-
-    For each valid combo, we:
-    - build pattern/replacement functions with exact-arg parity
-    - build dummy args matching the signature (with CUDA fp16 tensors etc.)
-    - build scalar_workaround dict for any scalars/n_rep present
-    - call register_ad_pattern(...)
+    Generate patterns for matching repeat_kv + torch_attention.
+    Enumerates combinations across:
+    - is_causal: [False, True]
+    - has_scale: [False, True]
+    - has_attn_mask: [False, True]
+    - has_dropout: [False, True]
     """
     q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
     k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
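Note (illustrative sketch, not part of this diff): dropping the explicit expansion is only sound if `torch_attention_repeat_kv` follows the usual KV-head repetition semantics and `torch_attention` accepts grouped (un-repeated) KV heads directly; both are assumptions here. A reference sketch of that expansion:

```python
# Sketch only: reference repeat_kv semantics assumed for torch_attention_repeat_kv.
import torch


def repeat_kv_reference(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [B, n_kv, S, D] -> [B, n_kv * n_rep, S, D] by repeating each KV head n_rep times.
    b, n_kv, s, d = x.shape
    if n_rep == 1:
        return x
    return x[:, :, None, :, :].expand(b, n_kv, n_rep, s, d).reshape(b, n_kv * n_rep, s, d)


k = torch.randn(8, 1, 16, 64)
print(repeat_kv_reference(k, 7).shape)  # torch.Size([8, 7, 16, 64])
# The replacement hands the un-repeated k/v straight to torch_attention,
# which is assumed to handle grouped KV heads internally.
```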
@@ -487,24 +574,15 @@ def generate_and_register_grouped_attn_patterns(
     n_rep_val = 7
 
     total = 0
-    axes = ((False, True),) * 6
-    for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
-        if only_repeat_kv is not None:
-            if only_repeat_kv and not repeat_kv:
-                continue  # Skip patterns without repeat_kv
-            if not only_repeat_kv and repeat_kv:
-                continue  # Skip patterns with repeat_kv
-
-        pat_fn, rep_fn, argnames = make_grouped_attn_pair(
-            repeat_kv=repeat_kv,
+    axes = ((False, True),) * 4
+    for is_causal, has_scale, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_repeat_kv_torch_attn_pair(
             is_causal=is_causal,
             has_scale=has_scale,
-            enable_gqa=enable_gqa,
             has_attn_mask=has_attn_mask,
             has_dropout=has_dropout,
         )
 
-        # Build dummy args in the same positional order
         value_map = {
             "q": q,
             "k": k1,
@@ -539,12 +617,17 @@ def generate_and_register_grouped_attn_patterns(
     return total
 
 
-@TransformRegistry.register("match_grouped_attention_with_repeat_kv")
-class MatchGroupedAttentionWithRepeatKV(BaseTransform):
+@TransformRegistry.register("match_sdpa_to_torch_attention")
+class MatchSDPAToTorchAttention(BaseTransform):
     """
-    Match and replace grouped attention patterns WITH repeat_kv to
-    torch.ops.auto_deploy.torch_attention.
+    Match and replace SDPA patterns with torch.ops.auto_deploy.torch_attention.
 
+    This handles:
+    - sdpa --> torch_attention
+    - repeat_kv + sdpa --> torch_attention
+
+    This transform should run BEFORE MatchRepeatKVWithTorchAttention ("match_grouped_attention")
+    to ensure SDPA calls are converted first.
     """
 
     def _apply(
@@ -554,32 +637,33 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        def register_grouped_attention_with_repeat_kv(patterns: ADPatternMatcherPass):
-            return generate_and_register_grouped_attn_patterns(
-                patterns, register_ad_pattern, only_repeat_kv=True
-            )
+        def register_sdpa_to_torch_attention(patterns: ADPatternMatcherPass):
+            return generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern)
 
-        num_grouped_patterns = _apply_pattern(
-            gm, "Grouped Attention (with repeat_kv)", register_grouped_attention_with_repeat_kv
+        num_patterns = _apply_pattern(
+            gm, "SDPA to Torch Attention", register_sdpa_to_torch_attention
         )
 
         info = TransformInfo(
             skipped=False,
-            num_matches=num_grouped_patterns,
+            num_matches=num_patterns,
             is_clean=False,
             has_valid_shapes=False,
         )
         return gm, info
 
 
-@TransformRegistry.register("match_grouped_attention_without_repeat_kv")
-class MatchGroupedAttentionWithoutRepeatKV(BaseTransform):
+@TransformRegistry.register("match_grouped_attention")
+class MatchRepeatKVWithTorchAttention(BaseTransform):
     """
-    Match and replace grouped attention patterns WITHOUT repeat_kv to
-    torch.ops.auto_deploy.torch_attention.
+    Match and replace repeat_kv + torch_attention patterns with torch_attention.
 
-    This transform should run AFTER match_grouped_attention_with_repeat_kv
-    to avoid incorrectly matching patterns that should have repeat_kv.
+    This handles:
+    - repeat_kv + torch_attention --> torch_attention (removes redundant repeat_kv)
+    - torch_attention --> torch_attention (identity, catches any remaining patterns)
+
+    This transform should run AFTER match_sdpa_to_torch_attention to ensure
+    we match the repeat_kv + torch_attention pattern correctly.
     """
 
     def _apply(
@@ -589,26 +673,18 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        def register_grouped_attention_without_repeat_kv(patterns: ADPatternMatcherPass):
-            return generate_and_register_grouped_attn_patterns(
-                patterns, register_ad_pattern, only_repeat_kv=False
+        def register_repeat_kv_with_torch_attention(patterns: ADPatternMatcherPass):
+            return generate_and_register_repeat_kv_torch_attn_patterns(
+                patterns, register_ad_pattern
             )
 
-        num_grouped_patterns = _apply_pattern(
-            gm,
-            "Grouped Attention (without repeat_kv)",
-            register_grouped_attention_without_repeat_kv,
+        num_patterns = _apply_pattern(
+            gm, "Repeat KV with Torch Attention", register_repeat_kv_with_torch_attention
         )
 
-        if num_grouped_patterns == 0:
-            ad_logger.warning(
-                "Fail to find any Group Attention Pattern (without repeat_kv), "
-                "output or performance may be incorrect"
-            )
-
         info = TransformInfo(
             skipped=False,
-            num_matches=num_grouped_patterns,
+            num_matches=num_patterns,
             is_clean=False,
             has_valid_shapes=False,
         )
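Note (illustrative sketch, not part of this diff): the split into an SDPA-conversion pass followed by a repeat_kv-removal pass rests on SDPA's grouped-query handling being equivalent to explicit KV repetition. A quick check with stock PyTorch (`enable_gqa` requires a recent release, 2.5+); `torch_attention` is assumed to implement the same math:

```python
# Sketch only: enable_gqa vs. explicit KV repetition should give the same result.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)  # 8 query heads
k = torch.randn(2, 2, 16, 64)  # 2 KV heads -> n_rep = 4
v = torch.randn(2, 2, 16, 64)

gqa_out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

n_rep = q.shape[1] // k.shape[1]
repeated_out = F.scaled_dot_product_attention(
    q, k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
)

print(torch.allclose(gqa_out, repeated_out, atol=1e-5))  # expected: True
```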