
Commit 2db03ba

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Add support for grid dims in GPUMesh
Of course no communication can happen across grid dimensions (unlike over the WG dim), but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
1 parent: 0b3f0e1 · commit: 2db03ba
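For illustration, a minimal sketch of what the new option enables (not part of the commit; the `plgpu` alias follows the JAX test suite's import convention):

```python
from jax.experimental.pallas import mosaic_gpu as plgpu  # alias as in the tests

# Two grid axes plus the warpgroup axis: this describes a 2x2 grid of blocks,
# each running 2 Mosaic warpgroups. Communication is only possible over "wg";
# the grid axes exist purely to launch multiple blocks.
mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))
```

The mesh is then used with `pl.core_map`, exactly as in the new test at the bottom of this commit.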

3 files changed: +35 −6 lines

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -466,13 +466,22 @@ def __post_init__(self):
           "Requested too many CUDA threads per block. Each Mosaic thread"
           " corresponds to 128 CUDA threads."
       )
+    if self.cluster:
+      raise NotImplementedError(
+          "Pallas/MosaicGPU does not support clusters yet."
+      )
 
   @property
   def shape(self):
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, self.num_threads))
+      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
     else:
-      pairs = (*zip(self.axis_names, self.grid), (_WARPGROUP_AXIS_NAME, 1))
+      pairs = tuple(
+          zip(
+              (*self.axis_names, _WARPGROUP_AXIS_NAME),
+              (*self.grid, *self.cluster, 1),
+          )
+      )
     return collections.OrderedDict(pairs)
 
 
@@ -485,11 +494,10 @@ def _gpu_mesh_discharge_rule(
 ):
   del out_avals
   assert isinstance(mesh, GPUMesh)
-  if mesh.grid or mesh.cluster:
+  if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
-  threads_axis_name, num_threads = list(mesh.shape.items())[0]
   def body(*args):
     # Due to aliasing, args contains aliased inputs and outputs so we remove
     # outputs.
@@ -503,7 +511,7 @@ def body(*args):
       in_specs=[any_spec] * len(in_avals),
       out_specs=[any_spec] * len(in_avals),
       input_output_aliases={i: i for i in range(len(in_avals))},
-      grid=((threads_axis_name, num_threads),),
+      grid=tuple(mesh.shape.items()),
   )(*args)
   return out, ()
```
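For reference, a plain-Python sketch of what the updated `shape` property computes for the mesh used in the new test (values assumed from that test; the real logic is the diff above):

```python
import collections

# Mesh parameters from the test below: a 2x2 grid, no cluster, 2 warpgroups.
grid, cluster, num_threads = (2, 2), (), 2
axis_names = ("x", "y", "wg")

# num_threads is not None, so grid, cluster and warpgroup counts are zipped
# against the user-provided axis names in order.
pairs = zip(axis_names, (*grid, *cluster, num_threads))
print(collections.OrderedDict(pairs))
# OrderedDict([('x', 2), ('y', 2), ('wg', 2)])
```

The discharge rule then forwards every named dimension to the inner `pallas_call` via `grid=tuple(mesh.shape.items())`, where it previously passed only the warpgroup axis.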

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -1125,7 +1125,6 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
   if axis_name == grid_names[-1]:
     return mgpu.warpgroup_idx(sync=False)
   else:
-    raise NotImplementedError  # The code below is untested
     idx = grid_names.index(axis_name)
     return arith_dialect.index_cast(
         ir.IntegerType.get_signless(32),
```
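Dropping the `NotImplementedError` means `axis_index` over grid axes is now lowered (and exercised by the new test). A small host-side sketch of the combined-axis index convention that test relies on, assuming the first-listed axis varies slowest:

```python
x_size, y_size = 2, 2  # grid sizes from the test below
for x in range(x_size):
    for y in range(y_size):
        # Combined indices, with the first-listed axis most significant:
        xy_idx = x * y_size + y  # what jax.lax.axis_index(("x", "y")) yields
        yx_idx = y * x_size + x  # what jax.lax.axis_index(("y", "x")) yields
        print((x, y), "xy:", xy_idx, "yx:", yx_idx)
```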

tests/pallas/mosaic_gpu_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -1025,6 +1025,28 @@ def kernel():
         f(), np.repeat(np.arange(2), 128).reshape(2, 128)
     )
 
+  def test_multiple_wg_with_grid(self):
+    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))
+
+    @jax.jit
+    def f():
+      @pl.run_state
+      def inner(y_ref):
+        @pl.core_map(mesh)
+        def kernel():
+          xy_idx = jax.lax.axis_index(("x", "y"))
+          yx_idx = jax.lax.axis_index(("y", "x"))
+          wg_idx = jax.lax.axis_index("wg")
+          num_wgs = jax.lax.psum(1, "wg")
+          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
+              yx_idx * num_wgs + wg_idx, (128,)
+          )
+      y_init = jnp.zeros((4, 2, 128), np.int32)
+      return inner(y_init)
+    np.testing.assert_array_equal(
+        f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
+    )
+
 
 if __name__ == "__main__":
   absltest.main()
```
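To see where the expected `[0, 1, 4, 5, 2, 3, 6, 7]` pattern comes from, here is a small NumPy check that replays the kernel's writes on the host (same index conventions as the sketch above):

```python
import numpy as np

# Each (x, y) block writes yx_idx * num_wgs + wg into row xy_idx, column wg.
num_wgs = 2
expected = np.zeros((4, 2), np.int32)
for x in range(2):
    for y in range(2):
        xy_idx = x * 2 + y  # row the kernel writes to
        yx_idx = y * 2 + x  # base value stored in that row
        for wg in range(num_wgs):
            expected[xy_idx, wg] = yx_idx * num_wgs + wg
print(expected.ravel().tolist())  # [0, 1, 4, 5, 2, 3, 6, 7]
```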
