File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 2222from jax ._src import core as jax_core
2323from jax ._src import tree_util
2424from jax ._src import typing as jax_typing
25+ from jax ._src import util
2526import jax ._src .lax as lax
2627from jax ._src .lax .control_flow import conditionals
2728from jax ._src .pallas import core as pl_core
@@ -132,6 +133,7 @@ def _make_kernel(body,
132133 out_shape : object ,
133134 mesh : pl_core .Mesh ,
134135 scratch_shapes : pl_core .ScratchShapeTree = (),
136+ name : str | None = None ,
135137 ** mesh_kwargs
136138 ):
137139 if unwrap_out := not isinstance (out_shape , (tuple , list )):
@@ -155,7 +157,8 @@ def wrapper(*operands):
155157 out_shape ,
156158 )
157159
158- @pl_core .core_map (mesh , ** mesh_kwargs )
160+
161+ @pl_core .core_map (mesh , ** mesh_kwargs , name = name or util .fun_name (body ))
159162 def _ ():
160163 return pl_primitives .run_scoped (
161164 functools .partial (body , * arg_refs , * out_refs ),
You can’t perform that action at this time.
0 commit comments