
Commit 84a303f

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Allow allocating transformed refs in run_scoped
PiperOrigin-RevId: 688448592
1 parent ebb75db commit 84a303f
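
A minimal usage sketch of what this change enables, adapted from the test added at the bottom of this commit (the imports are the usual JAX/Pallas ones and a Mosaic GPU-capable device is assumed; this is illustrative, not part of the diff):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))

def kernel(x_ref, o_ref, barrier_ref):
  def body(tmp_ref):
    # tmp_ref is scratch SMEM allocated by run_scoped with the transforms above.
    plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    o_ref[...] = tmp_ref[...] * 2
  # New in this commit: the SMEM request passed to run_scoped may carry transforms.
  pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))

f = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    out_specs=plgpu.GPUBlockSpec(
        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
    ),
    scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x * 2)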

File tree

5 files changed: +80 additions, −16 deletions

jax/_src/pallas/core.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def get_array_aval(self) -> jax_core.ShapedArray:
         self.shape, dtype, memory_space=self.memory_space
     )
 
-  def get_ref_aval(self) -> AbstractMemoryRef:
+  def get_ref_aval(self) -> TransformedRef | AbstractMemoryRef:
     # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
     # try to apply JAX ops to it.
     return AbstractMemoryRef(

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 27 additions & 2 deletions
@@ -14,6 +14,8 @@
 
 """Contains GPU-specific Pallas abstractions."""
 
+from __future__ import annotations
+
 import abc
 import collections
 from collections.abc import Sequence
@@ -73,9 +75,32 @@ class GPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value
 
-  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+  def __call__(
+      self,
+      shape: tuple[int, ...],
+      dtype: jnp.dtype,
+      transforms: Sequence[MemoryRefTransform] = (),
+  ):
     # A convenience function for constructing MemoryRef types.
-    return pallas_core.MemoryRef(shape, dtype, memory_space=self)
+    return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
+
+
+@dataclasses.dataclass(frozen=True)
+class GPUMemoryRef(pallas_core.MemoryRef):
+  transforms: Sequence[MemoryRefTransform] = ()
+
+  def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef:
+    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    for t in self.transforms:
+      aval = t(aval)
+    ref = pallas_core.TransformedRef(
+        AbstractMemoryRef(aval, memory_space=self.memory_space), ()
+    )
+    for t in reversed(self.transforms):
+      ref = t.undo(ref)
+    if not ref.transforms:
+      return ref.ref
+    return ref
 
 
 class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
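
Read together, these two hunks mean that a call like plgpu.SMEM((128, 128), jnp.float32, transforms=...) now builds a GPUMemoryRef rather than a plain MemoryRef. A short sketch of how its get_ref_aval plays out (illustrative only; the transform values are taken from the test below, and get_ref_aval is an internal method):

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu

ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
# GPUMemorySpace.__call__ now forwards the transforms to a GPUMemoryRef.
smem = plgpu.SMEM((128, 128), jnp.float32, transforms=ts)
# get_ref_aval() applies the transforms to the ShapedArray to compute the
# transformed aval, then undoes them, yielding a TransformedRef that pairs the
# underlying AbstractMemoryRef with the transform stack run_scoped re-attaches.
ref_aval = smem.get_ref_aval()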

jax/_src/pallas/pallas_call.py

Lines changed: 3 additions & 12 deletions
@@ -1382,17 +1382,6 @@ def _ensure_2d_error_shape(arg):
   return new_error, results
 checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
 
-# All of those shenanigans are because we can't make TransformedRef a PyTree,
-# because they should appear as atomic JAX values to the users.
-@lu.transformation
-def wrap_with_transforms(transforms, *args):
-  new_args = tuple(
-      state_types.TransformedRef(a, t) if t else a
-      for a, t in zip(args, transforms)
-  )
-  res = yield new_args, {}
-  yield res
-
 
 @weakref_lru_cache
 def _trace_kernel_to_jaxpr(
@@ -1410,7 +1399,9 @@ def _trace_kernel_to_jaxpr(
      kernel_avals))
   wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), kernel_in_tree)
-  wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
+  wrapped_kernel_fun = primitives.wrap_with_transforms(
+      wrapped_kernel_fun, kernel_in_transforms
+  )
   debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,

jax/_src/pallas/primitives.py

Lines changed: 26 additions & 1 deletion
@@ -39,6 +39,7 @@
 from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
+from jax._src.state import types as state_types
 from jax._src.state import primitives as sp
 from jax.interpreters import mlir
 import jax.numpy as jnp
@@ -816,6 +817,20 @@ def debug_print_lowering_rule(ctx, *args, **params):
   return result
 
 
+# All of those shenanigans are because we can't make TransformedRef a PyTree,
+# because they should appear as atomic JAX values to the users.
+# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
+# inferred by the compiler.
+@lu.transformation
+def wrap_with_transforms(transforms, *args):
+  new_args = tuple(
+      state_types.TransformedRef(a, t) if t else a
+      for a, t in zip(args, transforms)
+  )
+  res = yield new_args, {}
+  yield res
+
+
 run_scoped_p = jax_core.Primitive("run_scoped")
 run_scoped_p.multiple_results = True
 
@@ -829,7 +844,17 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
   """
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [t.get_ref_aval() for t in flat_types]
+  # We allow ref avals to be transformed references.
+  ref_avals = [t.get_ref_aval() for t in flat_types]
+  avals = [
+      t.ref if isinstance(t, state_types.TransformedRef) else t
+      for t in ref_avals
+  ]
+  ref_transforms = tuple(
+      t.transforms if isinstance(t, state_types.TransformedRef) else ()
+      for t in ref_avals
+  )
+  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
   # Turn the function into a jaxpr. The body of run_scoped may have
   # effects (IO) on constvars (i.e. variables inherited from the
   # parent scope). Jax can't reason about effects to references that
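
For readers unfamiliar with the @lu.transformation idiom above: the code before the first yield rewrites the incoming arguments, the yield hands them to the wrapped function and receives its results, and the final yield returns those results unchanged. A plain-function analogue of wrap_with_transforms, as a sketch only (not the actual implementation; it only assumes TransformedRef from jax._src.state.types as imported in this diff):

from jax._src.state import types as state_types

def wrap_with_transforms_plain(fun, transforms):
  # Arguments with a non-empty transform stack reach the kernel body wrapped
  # back into a TransformedRef; the rest pass through untouched.
  def wrapped(*args):
    new_args = tuple(
        state_types.TransformedRef(a, t) if t else a
        for a, t in zip(args, transforms)
    )
    return fun(*new_args)
  return wrapped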

tests/pallas/mosaic_gpu_test.py

Lines changed: 23 additions & 0 deletions
@@ -287,6 +287,29 @@ def kernel(x_ref, o_ref, barrier_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  def test_scoped_copy_with_transforms(self):
+    ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
+    def kernel(x_ref, o_ref, barrier_ref):
+      def body(tmp_ref):
+        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
+        plgpu.barrier_wait(barrier_ref)
+        o_ref[...] = tmp_ref[...] * 2
+      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
+
+    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
+    out_spec = plgpu.GPUBlockSpec(
+        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
+    )
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
+        in_specs=(in_spec,),
+        out_specs=out_spec,
+        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
+    )
+    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
+    np.testing.assert_array_equal(f(x), x * 2)
+
   def test_copy_with_transforms_and_indexing(self):
     def kernel(x_ref, o_ref, barrier_ref):
       for i in range(2):
