Commit f847350

superbobry authored and jax authors committed
Removed kernel_regeneration_util from Mosaic
It was only used for persisting kernel metadata, and that can be done via jax.named_scope instead.

PiperOrigin-RevId: 642195336
1 parent 11370b7 commit f847350
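As a minimal sketch of the replacement pattern described above (the metadata keys and values are illustrative, not taken from this commit): rather than serializing a metadata dict into a kernel_regeneration_metadata attribute on the custom call, the same information can be folded into the kernel name and recorded with jax.named_scope, so it still appears in HLO dumps and profiles.

import jax
import jax.numpy as jnp

# Illustrative metadata; in the updated splash attention kernels this role is
# played by the block-size/layout dict passed to get_kernel_name.
metadata = dict(block_q=512, block_kv=1024, is_mqa=False)
kernel_name = "my_kernel_" + "_".join(f"{k}={v}" for k, v in sorted(metadata.items()))

# Ops traced under the scope carry its name, so the metadata survives into
# HLO dumps and xprof without a separate custom-call attribute.
with jax.named_scope(kernel_name):
  out = jnp.ones((8, 128)) * 2.0  # stand-in for the actual pallas_call invocation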

7 files changed: +52 -150 lines changed


jax/BUILD

Lines changed: 0 additions & 1 deletion
@@ -626,7 +626,6 @@ pytype_strict_library(
         ":pallas",  # build_cleaner: keep
         ":tpu_custom_call",
         "//jax/_src/pallas/mosaic:core",
-        "//jax/_src/pallas/mosaic:kernel_regeneration_util",
         "//jax/_src/pallas/mosaic:lowering",
         "//jax/_src/pallas/mosaic:pallas_call_registration",  # build_cleaner: keep
         "//jax/_src/pallas/mosaic:pipeline",

jax/_src/pallas/mosaic/BUILD

Lines changed: 0 additions & 8 deletions
@@ -77,14 +77,6 @@ py_library(
     ] + py_deps("numpy"),
 )
 
-py_library(
-    name = "kernel_regeneration_util",
-    srcs = ["kernel_regeneration_util.py"],
-    deps = [
-        "//jax:mlir",
-    ],
-)
-
 py_library(
     name = "pipeline",
     srcs = ["pipeline.py"],

jax/_src/pallas/mosaic/kernel_regeneration_util.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 4 additions & 8 deletions
@@ -78,9 +78,6 @@ def pallas_call_tpu_lowering_rule(
   mlir_ctx.load_all_available_dialects()
   tpu.register_dialect(mlir_ctx)
   dimension_semantics = mosaic_params.get("dimension_semantics", None)
-  kernel_regeneration_metadata = mosaic_params.get(
-      "kernel_regeneration_metadata"
-  )
   mosaic_module, extra_args = lowering.lower_jaxpr_to_module(
       mlir_ctx, grid_mapping, in_shapes, out_shapes, jaxpr,
       dimension_semantics=dimension_semantics, mesh=mesh)
@@ -101,11 +98,10 @@ def _lower_fun(*args):
         out_avals,
         backend=ctx.module_context.backend,
         kernel_name=name,
-        kernel_regeneration_metadata=kernel_regeneration_metadata,
-        cost_estimate=mosaic_params.get("cost_estimate", None),
-        vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes", None),
-        flags=mosaic_params.get("flags", None),
-        allow_input_fusion=mosaic_params.get("allow_input_fusion", None),
+        cost_estimate=mosaic_params.get("cost_estimate"),
+        vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
+        flags=mosaic_params.get("flags"),
+        allow_input_fusion=mosaic_params.get("allow_input_fusion"),
         input_output_aliases=input_output_aliases,
     )(
         *dynamic_grid_args,

jax/_src/tpu_custom_call.py

Lines changed: 3 additions & 12 deletions
@@ -195,7 +195,6 @@ def _tpu_custom_call_lowering(
     *in_nodes,  # pylint: disable=missing-function-docstring
     config: CustomCallBackendConfig,
     kernel_name: str | None,
-    kernel_regeneration_metadata: bytes | None,
     out_avals: Any,
     input_output_aliases: tuple[tuple[int, int], ...],
 ) -> ...:
@@ -248,15 +247,11 @@ def _tpu_custom_call_lowering(
       ]),
   )
 
-  # Add kernel_name and kernel_regeneration_metadata as attributes to the
-  # custom call op. This is because we do not want to pollute the backend_config
-  # with this information.
+  # Add kernel_name and kernel_metadata as attributes to the custom call op.
+  # This is because we do not want to pollute the backend_config with this
+  # information.
   if kernel_name is not None:
     call.attributes["kernel_name"] = ir.StringAttr.get(kernel_name)
-  if kernel_regeneration_metadata is not None:
-    call.attributes["kernel_regeneration_metadata"] = ir.StringAttr.get(
-        base64.b64encode(kernel_regeneration_metadata)
-    )
   if multiple_results:
     results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i))
                for i in range(len(out_avals))]
@@ -376,7 +371,6 @@ def as_tpu_kernel(
     backend: str | xla_client.Client = "tpu",
     device_type: str | None = None,
     kernel_name: str | None = None,
-    kernel_regeneration_metadata: bytes | None = None,
     vmem_limit_bytes: int | None = None,
     flags: dict[str, bool | int | float] | None = None,
     allow_input_fusion: list[bool] | None = None,
@@ -435,7 +429,6 @@ def as_tpu_kernel(
       has_communication=has_communication,
       has_custom_barrier=has_custom_barrier,
       kernel_name=kernel_name,
-      kernel_regeneration_metadata=kernel_regeneration_metadata,
       cost_estimate=cost_estimate,
       vmem_limit_bytes=vmem_limit_bytes,
       flags=flags,
@@ -455,7 +448,6 @@ def _lowered_as_tpu_kernel(
     has_communication: bool = False,
     has_custom_barrier: bool = False,
     kernel_name: str | None = None,
-    kernel_regeneration_metadata: bytes | None = None,
     vmem_limit_bytes: int | None = None,
     flags: dict[str, bool | int | float] | None = None,
     allow_input_fusion: list[bool] | None = None,
@@ -496,7 +488,6 @@ def apply_kernel(*args, collective_id: int | None = None):
         *args,
         config=config,
         kernel_name=kernel_name,
-        kernel_regeneration_metadata=kernel_regeneration_metadata,
         out_avals=out_avals,
         input_output_aliases=input_output_aliases,
     )

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 45 additions & 59 deletions
@@ -16,10 +16,11 @@
 
 from __future__ import annotations
 
+from collections.abc import Mapping
 import dataclasses
 import enum
 import functools
-from typing import Any, Callable, Literal, NamedTuple, Union, Optional, overload
+from typing import Any, Callable, Literal, NamedTuple, Optional, Union, overload
 
 import jax
 from jax import ad_checkpoint
@@ -89,10 +90,13 @@ class SegmentIds(NamedTuple):
 
 
 def get_kernel_name(
-    is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str
+    block_metadata: Mapping[str, Any],
+    is_mqa: bool,
+    save_residuals: bool,
+    is_segmented: bool,
+    phase: str,
 ) -> str:
   """Returns a unique name for all SplashAttention kernel variants."""
-
   assert phase == "dq" or phase == "dkv" or phase == "fwd"
   # Saving residuals is supported only for the fwd phase.
   assert not save_residuals or phase == "fwd"
@@ -103,7 +107,9 @@ def get_kernel_name(
     residuals = "_no_residuals"
   attention_type = "mqa" if is_mqa else "mha"
   segments = "_segmented" if is_segmented else ""
-  return f"splash_{attention_type}_{phase}{segments}{residuals}"
+  return f"splash_{attention_type}_{phase}{segments}{residuals}_" + "_".join(
+      f"{k}={v}" for k, v in sorted(block_metadata.items())
+  )
 
 
 # Reference attention implementations
@@ -1054,28 +1060,17 @@ def logsumexp_index_map(h, i, *_):
     out_shapes += [None]
     out_specs += [None]
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_sizes=dataclasses.asdict(block_sizes),
-      is_mqa=is_mqa,
-      save_residuals=save_residuals,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-      residual_checkpoint_name=residual_checkpoint_name,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
-
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("parallel", "arbitrary", "arbitrary"),
       flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
   )
 
   kernel_name = get_kernel_name(
-      is_mqa, save_residuals, segment_ids is not None, "fwd"
+      dataclasses.asdict(block_sizes),
+      is_mqa=is_mqa,
+      save_residuals=save_residuals,
+      is_segmented=segment_ids is not None,
+      phase="fwd",
   )
 
   if fwd_mask_info.data_next is not None:
@@ -1526,28 +1521,24 @@ def logsumexp_index_map(h, i, *_):
   )
   num_scalar_prefetch = 3
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_q_dq=bq,
-      block_kv_dq=bkv,
-      q_layout=q_layout,
-      k_layout=k_layout,
-      v_layout=v_layout,
-      is_mqa=is_mqa,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
       flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
   )
 
-  kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dq")
+  kernel_name = get_kernel_name(
+      dict(
+          block_q_dq=bq,
+          block_kv_dq=bkv,
+          q_layout=q_layout,
+          k_layout=k_layout,
+          v_layout=v_layout,
+      ),
+      is_mqa=is_mqa,
+      save_residuals=False,
+      is_segmented=segment_ids is not None,
+      phase="dq",
+  )
   with jax.named_scope(kernel_name):
     _, dq = pl.pallas_call(
         kernel,
@@ -2072,35 +2063,30 @@ def logsumexp_index_map(
   )
   num_scalar_prefetch = 3
 
-  # Attach useful metadata to the custom-call HLO op.
-  # Having this information available in an HLO-dump or xprof is valuable for
-  # debugging and performance investigation.
-  metadata_dict = dict(
-      block_q_dkv=bq,
-      block_kv_dkv=bkv,
-      block_kv_dkv_compute=bkv_compute,
-      q_layout=q_layout,
-      k_layout=k_layout,
-      v_layout=v_layout,
-      use_fused_bwd_kernel=use_fused_bwd_kernel,
-      is_mqa=is_mqa,
-      mask_value=mask_value,
-      is_segmented=segment_ids is not None,
-      attn_logits_soft_cap=attn_logits_soft_cap,
-  )
-
-  mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict)
   # We set all dimensions to arbitrary because:
   # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
   #    megacore
   # 2) for heads, we are reducing over heads
   # 3) for q_seq_len, we are reducing over it to compute dkv
-  mosaic_params.update(
+  mosaic_params = dict(
       dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
       flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
   )
 
-  kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dkv")
+  kernel_name = get_kernel_name(
+      dict(
+          block_q_dkv=bq,
+          block_kv_dkv=bkv,
+          block_kv_dkv_compute=bkv_compute,
+          q_layout=q_layout,
+          k_layout=k_layout,
+          v_layout=v_layout,
+      ),
+      is_mqa=is_mqa,
+      save_residuals=False,
+      is_segmented=segment_ids is not None,
+      phase="dkv",
+  )
   with jax.named_scope(kernel_name):
     _, _, _, dq_unreduced, dk, dv = pl.pallas_call(
         kernel,
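For illustration only (the block sizes below are made up, not taken from this commit), the updated get_kernel_name now folds the block metadata into the name that jax.named_scope records:

# Hypothetical call to the new get_kernel_name signature.
get_kernel_name(
    dict(block_q=512, block_kv=1024),
    is_mqa=False,
    save_residuals=False,
    is_segmented=True,
    phase="fwd",
)
# -> "splash_mha_fwd_segmented_no_residuals_block_kv=1024_block_q=512"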

jax/experimental/pallas/tpu.py

Lines changed: 0 additions & 2 deletions
@@ -20,8 +20,6 @@
 from jax._src.pallas.mosaic.core import semaphore
 from jax._src.pallas.mosaic.core import SemaphoreType
 from jax._src.pallas.mosaic.core import TPUMemorySpace
-from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
-from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
 from jax._src.pallas.mosaic.lowering import LoweringException
 from jax._src.pallas.mosaic.pipeline import BufferedRef
 from jax._src.pallas.mosaic.pipeline import emit_pipeline
