 
 from __future__ import annotations
 
+from collections.abc import Mapping
 import dataclasses
 import enum
 import functools
-from typing import Any, Callable, Literal, NamedTuple, Union, Optional, overload
+from typing import Any, Callable, Literal, NamedTuple, Optional, Union, overload
 
 import jax
 from jax import ad_checkpoint
@@ -89,10 +90,13 @@ class SegmentIds(NamedTuple):
 
 
 def get_kernel_name(
-    is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str
+    block_metadata: Mapping[str, Any],
+    is_mqa: bool,
+    save_residuals: bool,
+    is_segmented: bool,
+    phase: str,
 ) -> str:
   """Returns a unique name for all SplashAttention kernel variants."""
-
   assert phase == "dq" or phase == "dkv" or phase == "fwd"
   # Saving residuals is supported only for the fwd phase.
   assert not save_residuals or phase == "fwd"
@@ -103,7 +107,9 @@ def get_kernel_name(
     residuals = "_no_residuals"
   attention_type = "mqa" if is_mqa else "mha"
   segments = "_segmented" if is_segmented else ""
-  return f"splash_{attention_type}_{phase}{segments}{residuals}"
+  return f"splash_{attention_type}_{phase}{segments}{residuals}_" + "_".join(
+      f"{k}={v}" for k, v in sorted(block_metadata.items())
+  )
 
 
 # Reference attention implementations
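
For reference, a minimal standalone sketch (not part of this diff; the block values are hypothetical) of the name the new return statement builds from a block-metadata mapping, here for a non-MQA, segmented, no-residuals forward kernel:

    # Hypothetical metadata; the real callers pass block sizes and layouts.
    block_metadata = {"block_q": 512, "block_kv": 1024}
    name = "splash_mha_fwd_segmented_no_residuals_" + "_".join(
        f"{k}={v}" for k, v in sorted(block_metadata.items())
    )
    # name == "splash_mha_fwd_segmented_no_residuals_block_kv=1024_block_q=512"

Sorting the items keeps the name deterministic regardless of the order in which the metadata dict was built.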
@@ -1054,28 +1060,17 @@ def logsumexp_index_map(h, i, *_):
     out_shapes += [None]
     out_specs += [None]
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_sizes=dataclasses.asdict(block_sizes),
-      is_mqa=is_mqa,
-      save_residuals=save_residuals,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-      residual_checkpoint_name=residual_checkpoint_name,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
-
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("parallel", "arbitrary", "arbitrary"),
       flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
   )
 
   kernel_name = get_kernel_name(
-      is_mqa, save_residuals, segment_ids is not None, "fwd"
+      dataclasses.asdict(block_sizes),
+      is_mqa=is_mqa,
+      save_residuals=save_residuals,
+      is_segmented=segment_ids is not None,
+      phase="fwd",
   )
 
   if fwd_mask_info.data_next is not None:
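
With the block metadata folded into the kernel name, it is the existing jax.named_scope(kernel_name) wrapper around pl.pallas_call that surfaces this information in HLO dumps and profiler traces (the purpose described by the removed comments). A minimal standalone sketch of that mechanism; the function and scope string below are made up for illustration:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      # Ops traced inside the scope carry its name, which is how the kernel
      # name (and the metadata embedded in it) shows up in HLO dumps and
      # profiler traces.
      with jax.named_scope("splash_mha_fwd_block_kv=1024_block_q=512"):
        return jnp.sin(x)

    f(jnp.ones((4,)))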
@@ -1526,28 +1521,24 @@ def logsumexp_index_map(h, i, *_):
   )
   num_scalar_prefetch = 3
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_q_dq=bq,
-      block_kv_dq=bkv,
-      q_layout=q_layout,
-      k_layout=k_layout,
-      v_layout=v_layout,
-      is_mqa=is_mqa,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
       flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
   )
 
-  kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dq")
+  kernel_name = get_kernel_name(
+      dict(
+          block_q_dq=bq,
+          block_kv_dq=bkv,
+          q_layout=q_layout,
+          k_layout=k_layout,
+          v_layout=v_layout,
+      ),
+      is_mqa=is_mqa,
+      save_residuals=False,
+      is_segmented=segment_ids is not None,
+      phase="dq",
+  )
   with jax.named_scope(kernel_name):
     _, dq = pl.pallas_call(
         kernel,
@@ -2072,35 +2063,30 @@ def logsumexp_index_map(
   )
   num_scalar_prefetch = 3
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_q_dkv=bq,
-      block_kv_dkv=bkv,
-      block_kv_dkv_compute=bkv_compute,
-      q_layout=q_layout,
-      k_layout=k_layout,
-      v_layout=v_layout,
-      use_fused_bwd_kernel=use_fused_bwd_kernel,
-      is_mqa=is_mqa,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
   # We set all dimensions to arbitrary because:
   # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
   #    megacore
   # 2) for heads, we are reducing over heads
   # 3) for q_seq_len, we are reducing over it to compute dkv
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
  )
 
-  kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dkv")
+  kernel_name = get_kernel_name(
+      dict(
+          block_q_dkv=bq,
+          block_kv_dkv=bkv,
+          block_kv_dkv_compute=bkv_compute,
+          q_layout=q_layout,
+          k_layout=k_layout,
+          v_layout=v_layout,
+      ),
+      is_mqa=is_mqa,
+      save_residuals=False,
+      is_segmented=segment_ids is not None,
+      phase="dkv",
+  )
   with jax.named_scope(kernel_name):
     _, _, _, dq_unreduced, dk, dv = pl.pallas_call(
         kernel,