Skip to content

Commit 2fbc76d

Browse files
brianwa84Google-ML-Automation
authored andcommitted
Use util.fun_name(body) instead of "_" as the default name for the core_map in pl.kernel.
PiperOrigin-RevId: 835137081
1 parent f8c0b9a commit 2fbc76d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

jax/_src/pallas/helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from jax._src import core as jax_core
2323
from jax._src import tree_util
2424
from jax._src import typing as jax_typing
25+
from jax._src import util
2526
import jax._src.lax as lax
2627
from jax._src.lax.control_flow import conditionals
2728
from jax._src.pallas import core as pl_core
@@ -132,6 +133,7 @@ def _make_kernel(body,
132133
out_shape: object,
133134
mesh: pl_core.Mesh,
134135
scratch_shapes: pl_core.ScratchShapeTree = (),
136+
name: str | None = None,
135137
**mesh_kwargs
136138
):
137139
if unwrap_out := not isinstance(out_shape, (tuple, list)):
@@ -155,7 +157,8 @@ def wrapper(*operands):
155157
out_shape,
156158
)
157159

158-
@pl_core.core_map(mesh, **mesh_kwargs)
160+
161+
@pl_core.core_map(mesh, **mesh_kwargs, name=name or util.fun_name(body))
159162
def _():
160163
return pl_primitives.run_scoped(
161164
functools.partial(body, *arg_refs, *out_refs),

0 commit comments

Comments
 (0)