
Commit 784ebea

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Automatically squash a >3D logical grid into a 3D physical CUDA grid.
PiperOrigin-RevId: 702013252
1 parent 6b02950 commit 784ebea
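The change replaces the old "Only <=3D grids are supported" error with an automatic squash: when the parallel grid has more than three dimensions, all leading dimensions are merged into CUDA's Dimension.x, and per-axis program IDs are recovered with div/mod arithmetic. A minimal sketch of the scheme (the helper name squash_grid is illustrative, not from the commit):

import math

def squash_grid(parallel_grid):
  """Fold a >3D parallel grid into a 3D physical CUDA grid."""
  if len(parallel_grid) <= 3:
    # Small grids are padded with trivial dimensions instead.
    return (), list(parallel_grid) + [1] * (3 - len(parallel_grid))
  # All leading dimensions are merged into Dimension.x; the trailing two
  # keep Dimension.y and Dimension.z to themselves.
  squashed_dims = tuple(parallel_grid[:-2])
  return squashed_dims, [math.prod(squashed_dims), *parallel_grid[-2:]]

# A 4D logical grid (2, 3, 4, 5) launches as a (6, 4, 5) CUDA grid.
assert squash_grid((2, 3, 4, 5)) == ((2, 3), [6, 4, 5])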

2 files changed: +143 -17 lines

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 84 additions & 17 deletions
@@ -201,6 +201,7 @@ class ModuleContext:
   ]
   name_stack: source_info_util.NameStack
   traceback_caches: mlir.TracebackCaches
+  squashed_dims: tuple[int, ...]
 
   def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
     """Reserves a barrier.
@@ -403,12 +404,15 @@ def lower_jaxpr_to_module(
   parallel_grid = [
       d for i, d in enumerate(logical_grid) if i not in sequential_axes
   ]
-  if len(parallel_grid) < 3:
+  if len(parallel_grid) <= 3:
+    squashed_dims = ()
     parallel_grid += (1,) * (3 - len(parallel_grid))
-  elif len(parallel_grid) > 3:
-    raise NotImplementedError(
-        "Only <=3D grids are supported in Mosaic GPU lowering."
-    )
+  else:
+    # If we have >3 parallel dimensions, we merge all leading dimensions
+    # into the first (Dimension.x) CUDA grid dimension.
+    squashed_dims = parallel_grid[:-2]
+    parallel_grid = [math.prod(parallel_grid[:-2]), *parallel_grid[-2:]]
 
   if sequential_axes:
     # TODO(slebedev): Support multiple sequential axes.
     if len(sequential_axes) > 1:
@@ -496,7 +500,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
 
     parallel_count = it.count()
     program_ids_template = [
-        _program_id(next(parallel_count))
+        _program_id(next(parallel_count), squashed_dims=squashed_dims)
         if axis not in sequential_axes
         else None
         for axis in range(len(logical_grid))
@@ -520,6 +524,7 @@ def make_program_ids(step: ir.Value):
       runtime_barriers=grouped_barriers,
       name_stack=source_info_util.NameStack(),
       traceback_caches=mlir.TracebackCaches(),
+      squashed_dims=squashed_dims,
   )
   del runtime_smem, grouped_barriers, runtime_barriers
 
@@ -911,12 +916,42 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
     raise NotImplementedError("pl.program_id() is not supported in this context")
   return ctx.module_ctx.program_ids[axis]
 
-
-def _program_id(axis: int) -> ir.Value:
-  return arith_dialect.index_cast(
-      ir.IntegerType.get_signless(32),
-      gpu_dialect.block_id(gpu_dialect.Dimension(axis)),
-  )
+def _unravel_program_id(
+    block_id: ir.Value,
+    axis: int,
+    dimensions: tuple[int, ...],
+    row_major: bool = False
+) -> ir.Value:
+  """Computes the program ID for axes compressed into one block dimension."""
+  if row_major:
+    div_value = math.prod(dimensions[axis+1:])
+  else:
+    div_value = math.prod(dimensions[:axis])
+  div_value = _as_index(_i32_constant(div_value))
+  pid = arith_dialect.divui(block_id, div_value)
+  axis_size = _as_index(_i32_constant(dimensions[axis]))
+  pid = arith_dialect.remui(pid, axis_size)
+  return arith_dialect.index_cast(ir.IntegerType.get_signless(32), pid)
+
+
+def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value:
+  if squashed_dims:
+    if parallel_axis < len(squashed_dims):
+      # All squashed dimensions are mapped to Dimension.x.
+      block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x)
+      return _unravel_program_id(block_id, parallel_axis, squashed_dims)
+    else:
+      # Handle unsquashed axes.
+      return arith_dialect.index_cast(
+          ir.IntegerType.get_signless(32),
+          gpu_dialect.block_id(gpu_dialect.Dimension(
+              parallel_axis - len(squashed_dims) + 1)),
+      )
+  else:
+    return arith_dialect.index_cast(
+        ir.IntegerType.get_signless(32),
+        gpu_dialect.block_id(gpu_dialect.Dimension(parallel_axis)),
+    )
 
 
 @register_lowering_rule(primitives.num_programs_p)
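_unravel_program_id treats blockIdx.x as a linear index over the squashed dimensions and recovers axis i as (block_id // prod(dims[:i])) % dims[i]; the default is column-major, so axis 0 varies fastest. A pure-Python model of that arithmetic, for sanity-checking only (not part of the commit):

import math

def unravel(block_id, axis, dims, row_major=False):
  # Mirrors the div/mod sequence emitted by _unravel_program_id.
  div = math.prod(dims[axis + 1:]) if row_major else math.prod(dims[:axis])
  return (block_id // div) % dims[axis]

dims = (2, 3)  # squashed_dims for a (2, 3, y, z) logical grid
for block_id in range(math.prod(dims)):
  ids = [unravel(block_id, a, dims) for a in range(len(dims))]
  # Column-major round trip: axis 0 is the fastest-varying index.
  assert block_id == ids[0] + dims[0] * ids[1]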
@@ -1244,16 +1279,44 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
 
 @register_lowering_rule(lax.axis_index_p)
 def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
+  i32 = ir.IntegerType.get_signless(32)
   grid_names = ctx.module_ctx.grid_mapping.grid_names
+  squashed_dims = ctx.module_ctx.squashed_dims
+  if squashed_dims:
+    unsquashed_names = grid_names[-3:]
+    squashed_names = grid_names[:-3]
+  else:
+    # These are unused but initialized for type checkers.
+    unsquashed_names = ()
+    squashed_names = ()
   if grid_names and axis_name in grid_names:
     if axis_name == grid_names[-1]:
       return mgpu.warpgroup_idx(sync=True)
     else:
-      idx = grid_names.index(axis_name)
-      return arith_dialect.index_cast(
-          ir.IntegerType.get_signless(32),
-          gpu_dialect.block_id(gpu_dialect.Dimension(idx)),
-      )
+      if squashed_dims:
+        if axis_name in unsquashed_names:
+          # We add 1 to the index because the first dimension is the
+          # squashed dimension.
+          # e.g. for the grid (a, b, c, d, wg)
+          # squashed = (a, b) Mapped to Dimension.x (0)
+          # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2)
+          idx = unsquashed_names.index(axis_name) + 1
+          return arith_dialect.index_cast(
+              i32,
+              gpu_dialect.block_id(gpu_dialect.Dimension(idx)),
+          )
+        elif axis_name in squashed_names:
+          # All squashed dimensions are mapped to Dimension.x.
+          block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x)
+          axis = squashed_names.index(axis_name)
+          return _unravel_program_id(block_id, axis, squashed_dims)
+      else:
+        if axis_name in grid_names:
+          idx = grid_names.index(axis_name)
+          return arith_dialect.index_cast(
+              i32,
+              gpu_dialect.block_id(gpu_dialect.Dimension(idx)),
+          )
   raise ValueError(
       "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels"
   )
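Concretely, for the five-axis mesh from the comment above, (a, b, c, d, wg): wg resolves to the warpgroup index, c and d read Dimension.y and Dimension.z (hence the + 1 offset), and a and b are both unraveled out of Dimension.x. A hypothetical helper that tabulates where each named axis comes from (illustrative, not part of the commit):

def cuda_source(axis_name, grid_names):
  # Mirrors the dispatch in _axis_index_rule for the squashed (>3D) case.
  if axis_name == grid_names[-1]:
    return "warpgroup_idx"
  unsquashed_names = grid_names[-3:]
  squashed_names = grid_names[:-3]
  if axis_name in unsquashed_names:
    return f"Dimension({unsquashed_names.index(axis_name) + 1})"  # y or z
  return f"Dimension.x, unraveled at axis {squashed_names.index(axis_name)}"

for name in ("a", "b", "c", "d", "wg"):
  print(name, "->", cuda_source(name, ("a", "b", "c", "d", "wg")))
# a -> Dimension.x, unraveled at axis 0
# b -> Dimension.x, unraveled at axis 1
# c -> Dimension(1)
# d -> Dimension(2)
# wg -> warpgroup_idx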
@@ -1669,10 +1732,14 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:
 
 
 def _i32_constant(v: int) -> ir.Value:
+  if v < jnp.iinfo(jnp.int32).min or v > jnp.iinfo(jnp.int32).max:
+    raise ValueError(f"Integer constant out of range for i32: {v}")
   return arith_dialect.constant(ir.IntegerType.get_signless(32), v)
 
 
 def _i64_constant(v: int) -> ir.Value:
+  if v < jnp.iinfo(jnp.int64).min or v > jnp.iinfo(jnp.int64).max:
+    raise ValueError(f"Integer constant out of range for i64: {v}")
   return arith_dialect.constant(ir.IntegerType.get_signless(64), v)
 

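These new range checks are relevant to the squashing change: _unravel_program_id builds its divisors with _i32_constant from products of logical grid dimensions, and such products can exceed i32 range. A sketch of the failure mode the guard now surfaces (values are illustrative):

import math

INT32_MAX = 2**31 - 1

squashed_dims = (1024, 1024, 4096)  # product is 2**32
if math.prod(squashed_dims) > INT32_MAX:
  # With this commit, _i32_constant raises a ValueError here instead of
  # emitting a silently wrapped MLIR i32 constant.
  print(f"Integer constant out of range for i32: {math.prod(squashed_dims)}")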
tests/pallas/mosaic_gpu_test.py

Lines changed: 59 additions & 0 deletions
@@ -609,6 +609,30 @@ def kernel(o_ref):
         jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
     )
 
+  def test_program_id_in_squashed_grid(self):
+    # Tests whether a grid with >3 logical dimensions is correctly squashed to
+    # 3 CUDA grid dimensions.
+    grid = (2, 3, 4, 5)
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=(),
+        out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)),
+        out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32),
+        grid=grid,
+    )
+    def kernel(o_ref):
+      mult = 1
+      idx = 0
+      for axis in range(len(grid)-1, -1, -1):
+        idx += pl.program_id(axis) * mult
+        mult *= pl.num_programs(axis)
+      o_ref[...] = jnp.full(o_ref.shape, idx)
+
+    np.testing.assert_array_equal(
+        kernel()[:, :, :, :, 0],
+        jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid)
+    )
+
   def test_program_id_in_block_spec(self):
     @functools.partial(
         pl.pallas_call,
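The kernel's accumulation loop is exactly row-major linearization of the program IDs, which is why the expected output is a reshaped arange. An equivalent formulation of what the loop computes (illustrative, not from the test):

import numpy as np

grid = (2, 3, 4, 5)
ids = (1, 2, 0, 3)  # one program's (pl.program_id(0), ..., pl.program_id(3))
idx, mult = 0, 1
for axis in range(len(grid) - 1, -1, -1):
  idx += ids[axis] * mult
  mult *= grid[axis]
# The loop reproduces NumPy's C-order (row-major) ravel of the multi-index.
assert idx == np.ravel_multi_index(ids, grid, order="C") == 103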
@@ -1383,6 +1407,41 @@ def kernel():
         f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
     )
 
+  def test_multiple_wg_with_squashed_grid(self):
+    # Tests whether a grid with >3 logical dimensions is correctly squashed to
+    # 3 CUDA grid dimensions.
+    b = 4
+    x_dim = 3
+    y_dim = 5
+    z_dim = 7
+    num_threads = 2
+    mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim),
+                         num_threads=num_threads,
+                         axis_names=("b", "x", "y", "z", "wg"))
+
+    @jax.jit
+    def f():
+      @pl.run_state
+      def inner(y_ref):
+        @pl.core_map(mesh)
+        def _():
+          b_idx = jax.lax.axis_index("b")
+          x_idx = jax.lax.axis_index("x")
+          y_idx = jax.lax.axis_index("y")
+          z_idx = jax.lax.axis_index("z")
+          wg_idx = jax.lax.axis_index("wg")
+          bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg"))
+          y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to(
+              bxyzw_idx, (128,)
+          )
+      y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32)
+      return inner(y_init)
+    result = f()[:, :, :, :, :, 0]
+    ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape(
+        result.shape)
+    np.testing.assert_array_equal(result, ref)
+
+
   def test_cross_wg_barrier(self):
     mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",))
 
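For this test's mesh, the squash from the lowering change above should produce a (12, 5, 7) physical grid: b and x share Dimension.x, while y and z keep their own CUDA dimensions. A quick back-of-the-envelope check (illustrative):

import math

grid = (4, 3, 5, 7)            # (b, x_dim, y_dim, z_dim) from the test
squashed_dims = grid[:-2]      # "b" and "x" are merged into Dimension.x
physical = (math.prod(squashed_dims), *grid[-2:])
assert physical == (12, 5, 7)  # "y" -> Dimension.y, "z" -> Dimension.z
# "wg" is not a grid dimension at all; it maps to the warpgroup index.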