Skip to content

Commit f5d73b8

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Added test for custom pretty-printing rules
PiperOrigin-RevId: 745145207
1 parent b926fac commit f5d73b8

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,13 +857,14 @@ def _wgmma_ref_pp_eqn(
857857
acc, a, b, *leaves = eqn.invars
858858
a_transforms_treedef = eqn.params["a_transforms_tree"]
859859
b_transforms_treedef = eqn.params["b_transforms_tree"]
860+
split = getattr(a_transforms_treedef, "num_leaves", 0)
860861
a_transforms = (
861-
a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves])
862+
a_transforms_treedef.unflatten(leaves[:split])
862863
if a_transforms_treedef is not None
863864
else []
864865
)
865866
b_transforms = (
866-
b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :])
867+
b_transforms_treedef.unflatten(leaves[split:])
867868
if b_transforms_treedef is not None
868869
else []
869870
)

tests/pallas/mosaic_gpu_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,6 +2634,80 @@ class CoreMapWGTest(
26342634
...
26352635

26362636

2637+
class PrettyPrintingTest(PallasTest):
2638+
2639+
def test_load(self):
2640+
@functools.partial(
2641+
self.pallas_call,
2642+
out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
2643+
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
2644+
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
2645+
)
2646+
def kernel(x_ref, o_ref):
2647+
for i in range(2):
2648+
x = plgpu.load(x_ref, (i,))
2649+
o_ref[i, ...] = x
2650+
2651+
_ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32)))
2652+
2653+
def test_copy_primitives(self):
2654+
num_steps = 4
2655+
2656+
@functools.partial(
2657+
self.pallas_call,
2658+
out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32),
2659+
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
2660+
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
2661+
)
2662+
def kernel(x_gmem, o_gmem):
2663+
# ``plgpu.emit_pipeline`` is implemented in terms of async copy and
2664+
# synchronization primitives.
2665+
plgpu.emit_pipeline(
2666+
kernel_body,
2667+
in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))],
2668+
out_specs=[
2669+
pl.BlockSpec(
2670+
(64, 64),
2671+
lambda i: (0, i),
2672+
)
2673+
],
2674+
grid=(num_steps,),
2675+
max_concurrent_steps=2,
2676+
)(x_gmem, o_gmem)
2677+
2678+
def kernel_body(_, x_smem, o_smem):
2679+
o_smem[...] = x_smem[...] + 1.0
2680+
2681+
_ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32)))
2682+
2683+
def test_wgmma(self):
2684+
transforms = ()
2685+
if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
2686+
transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
2687+
2688+
@functools.partial(
2689+
self.pallas_call,
2690+
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
2691+
in_specs=[
2692+
plgpu.GPUBlockSpec(transforms=transforms),
2693+
plgpu.GPUBlockSpec(transforms=transforms),
2694+
],
2695+
)
2696+
def kernel(a_ref, b_ref, o_ref):
2697+
def scope(acc_ref):
2698+
plgpu.wgmma(acc_ref, a_ref[...], b_ref)
2699+
return acc_ref[...]
2700+
2701+
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
2702+
2703+
_ = str(
2704+
jax.make_jaxpr(kernel)(
2705+
jax.ShapeDtypeStruct((64, 128), jnp.float16),
2706+
jax.ShapeDtypeStruct((128, 192), jnp.float16),
2707+
)
2708+
)
2709+
2710+
26372711
class ExamplesTest(PallasTest):
26382712

26392713
# Basic

0 commit comments

Comments
 (0)