Commit b801539

apaszke authored and Google-ML-Automation committed
[Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes
This change removes the need to flatten the batch dimension into the sequence dimension in the flash attention kernel. The critical observation is that we can in fact collapse all squeezed dimensions into a single one in the TMA descriptor, letting us reduce its rank when necessary. Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU lowering, which I've fixed.

PiperOrigin-RevId: 701035277
1 parent d5bfafb commit b801539

File tree: 4 files changed, +110 −44 lines changed

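The observation in the commit message can be sanity-checked with plain NumPy, independent of Mosaic GPU: dimensions that are only ever indexed with scalars contribute an offset that is a linear combination of their strides, so a single dimension whose stride is the GCD of those strides can reach every element they address. The sketch below uses made-up shapes and is purely illustrative; it is not code from this change.

# Minimal NumPy sketch of collapsing "squeezed" leading dims into one.
import math
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
elem_strides = [s // x.itemsize for s in x.strides]      # (60, 20, 5, 1) in elements
lead = 2                                                 # collapse dims 0 and 1
g = math.gcd(*elem_strides[:lead])                       # the common stride
max_bound = sum((d - 1) * s // g
                for d, s in zip(x.shape[:lead], elem_strides[:lead])) + 1

# One collapsed leading dim of size max_bound and stride g; trailing dims unchanged.
view = np.lib.stride_tricks.as_strided(
    x, shape=(max_bound, *x.shape[lead:]),
    strides=(g * x.itemsize, *x.strides[lead:]))

for i0 in range(x.shape[0]):
  for i1 in range(x.shape[1]):
    flat = i0 * (elem_strides[0] // g) + i1 * (elem_strides[1] // g)
    np.testing.assert_array_equal(view[flat], x[i0, i1])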

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 15 additions & 14 deletions
@@ -360,10 +360,6 @@ def lower_jaxpr_to_module(
 
   assert len(jaxpr.outvars) == 0
   assert not grid_mapping.vmapped_dims
-  if len(grid_mapping.grid) > 3:
-    raise NotImplementedError(
-        "Only <=3D grids are supported in Mosaic GPU lowering."
-    )
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError(
         "Dynamic grid bounds not supported in the Mosaic GPU lowering."
@@ -397,16 +393,19 @@ def lower_jaxpr_to_module(
       f" {max_concurrent_steps=}, {delay_release=}"
   )
 
-  block = (128, 1, 1)
-  grid = grid_mapping.grid
   if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
     block = (128 * grid_mapping.grid[-1], 1, 1)
-    grid = grid[:-1]
-
-  grid = [d for i, d in enumerate(grid) if i not in sequential_axes]
-  if len(grid) < 3:
-    grid += (1,) * (3 - len(grid))
+    logical_grid = grid_mapping.grid[:-1]
   else:
+    block = (128, 1, 1)
+    logical_grid = grid_mapping.grid
+
+  parallel_grid = [
+      d for i, d in enumerate(logical_grid) if i not in sequential_axes
+  ]
+  if len(parallel_grid) < 3:
+    parallel_grid += (1,) * (3 - len(parallel_grid))
+  elif len(parallel_grid) > 3:
     raise NotImplementedError(
         "Only <=3D grids are supported in Mosaic GPU lowering."
     )
@@ -500,7 +499,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
           _program_id(next(parallel_count))
           if axis not in sequential_axes
           else None
-          for axis in range(len(grid_mapping.grid))
+          for axis in range(len(logical_grid))
       ]
 
       def make_program_ids(step: ir.Value):
@@ -788,7 +787,7 @@ def _(step, carry):
     prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
   module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
       body,
-      grid=grid,
+      grid=parallel_grid,
       cluster=(),
       block=block,
       in_shapes=in_structs_gmem,
@@ -806,7 +805,9 @@ def _(step, carry):
       prof_spec=prof_spec,
   )
 
-  return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx)
+  return LoweringResult(
+      module, parallel_grid, block, out_structs_gmem, prof_ctx
+  )
 
 
 mosaic_lowering_rules = {}
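A plain-Python restatement of the new grid handling above (values made up for illustration): the warpgroup dimension is folded into the block, sequential axes are dropped from the logical grid, and only the remaining parallel axes have to fit into CUDA's three grid dimensions.

# Hand-check of the parallel-grid construction; sizes are hypothetical.
logical_grid = (2, 16, 4, 8)        # e.g. (batch, q_tiles, heads, kv_steps)
sequential_axes = {3}               # the kv loop is sequential, not a launch axis

parallel_grid = [d for i, d in enumerate(logical_grid) if i not in sequential_axes]
if len(parallel_grid) < 3:
  parallel_grid += (1,) * (3 - len(parallel_grid))
elif len(parallel_grid) > 3:
  raise NotImplementedError("Only <=3D grids are supported in Mosaic GPU lowering.")

assert parallel_grid == [2, 16, 4]  # what gets passed as the CUDA grid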

jax/experimental/mosaic/gpu/core.py

Lines changed: 82 additions & 7 deletions
@@ -234,6 +234,57 @@ def batch(self, leading_rank: int) -> MemRefTransform:
     )
 
 
+@dataclasses.dataclass(frozen=True)
+class CollapseLeadingIndicesTransform(MemRefTransform):
+  """Collapses leading indices into one."""
+  strides: tuple[int, ...]
+
+  @functools.cached_property
+  def common_stride(self) -> int:
+    return math.gcd(*self.strides)
+
+  def apply(self, ref: ir.Value) -> ir.Value:
+    ref_ty = ir.MemRefType(ref.type)
+    strides, offset = ref_ty.get_strides_and_offset()
+    if offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      raise NotImplementedError("Dynamic offsets are not supported")
+    max_bound = sum(
+        (d - 1) * s // self.common_stride
+        for d, s in zip(
+            ref_ty.shape[: len(self.strides)], strides[: len(self.strides)]
+        )
+    ) + 1
+    new_shape = [max_bound, *ref_ty.shape[len(self.strides):]]
+    new_strides = [self.common_stride, *strides[len(self.strides):]]
+    new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
+    new_ref_ty = ir.MemRefType.get(
+        new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
+    )
+    return memref.reinterpret_cast(
+        new_ref_ty, ref, [], [], [],
+        static_offsets=[offset],
+        static_sizes=new_shape,
+        static_strides=new_strides,
+    )
+
+  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
+    index = ir.IndexType.get()
+    flat_idx = c(0, index)
+    for i, s in zip(idx[:len(self.strides)], self.strides):
+      flat_idx = arith.addi(
+          flat_idx, arith.muli(i, c(s // self.common_stride, index))
+      )
+    return (flat_idx, *idx[len(self.strides):])
+
+  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
+    if any(s != 1 for s in shape[:len(self.strides)]):
+      raise ValueError("Expected leading indices to be squeezed")
+    return (1, *shape[len(self.strides):])
+
+  def batch(self, leading_rank: int) -> MemRefTransform:
+    raise NotImplementedError  # Unused
+
+
 OnDeviceProfiler = profiler.OnDeviceProfiler
 
 
@@ -397,6 +448,17 @@ def async_copy(
         or gmem_ref.owner.opview.OPERATION_NAME != expected_name
     ):
       raise ValueError("GMEM reference in async_copy must be a kernel argument")
+    gmem_ref_ty = ir.MemRefType(gmem_ref.type)
+    gmem_strides, _ = gmem_ref_ty.get_strides_and_offset()
+    if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape):
+      raise NotImplementedError(
+          "async_copy assumes the GMEM reference is contiguous"
+      )
+    if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]):
+      raise ValueError(
+          "async_copy requires all GMEM strides except the last one to be a"
+          " multiple of 16 bytes"
+      )
 
     base_indices, slice_shape, is_squeezed = utils.parse_indices(
         gmem_slice, ir.MemRefType(gmem_ref.type).shape
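A quick illustration of the new stride requirement, with assumed shapes that are not taken from this change: for a contiguous row-major ref every stride except the innermost one is a multiple of the minor dimension size, so in effect the check asks that the minor dimension span a multiple of 16 bytes.

# Illustrative check of the 16-byte stride rule for contiguous refs.
def strides_ok(shape, element_bytewidth):
  strides = [1] * len(shape)
  for i in range(len(shape) - 2, -1, -1):   # row-major (contiguous) strides
    strides[i] = strides[i + 1] * shape[i + 1]
  return all(s * element_bytewidth % 16 == 0 for s in strides[:-1])

assert strides_ok((4, 3, 64), 2)       # bf16 rows of 128 bytes: OK
assert not strides_ok((4, 3, 60), 2)   # 120-byte rows are not 16-byte aligned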
@@ -421,9 +483,25 @@ def async_copy(
       dyn_base_indices = t.transform_index(dyn_base_indices)
       slice_shape = t.transform_shape(slice_shape)
 
+    num_squeezed_dims = len(squeezed_dims)
+    if len(slice_shape) > 5:
+      # We can try to collapse all squeezed dims into one.
+      if len(slice_shape) - num_squeezed_dims + 1 > 5:
+        raise ValueError(
+            "Async copies only support striding up to 5 dimensions"
+        )
+      collapse = CollapseLeadingIndicesTransform(
+          tuple(gmem_strides[d] for d in squeezed_dims)
+      )
+      gmem_transform = (*gmem_transform, collapse)
+      dyn_base_indices = collapse.transform_index(dyn_base_indices)
+      slice_shape = collapse.transform_shape(slice_shape)
+      num_squeezed_dims = 1
+      del squeezed_dims, sliced_dims  # Those no longer make sense.
+
     smem_ref_ty = ir.MemRefType(smem_ref.type)
     # We moved all squeezed dims to the front.
-    if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape):
+    if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape):
       raise ValueError(
           "Expected the SMEM reference to have the same shape as the"
           f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
@@ -437,7 +515,7 @@ def async_copy(
 
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
-    assert all(d == 1 for d in slice_shape[:len(squeezed_dims)])
+    assert all(d == 1 for d in slice_shape[:num_squeezed_dims])
     collective_size = 1
     if collective is not None:
       if isinstance(collective, gpu.Dimension):
@@ -446,14 +524,14 @@
     if collective_size > 1:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         # No need to partition squeezed dims. They don't even exist in smem_ref.
-        assert dim >= len(squeezed_dims)
+        assert dim >= num_squeezed_dims
        nonlocal smem_ref
        slice_shape[dim] //= num_chunks
        block_offset = arith.muli(idx, c(slice_shape[dim], index))
        dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset)
        smem_ref = utils.memref_slice(
            smem_ref,
-            (slice(None),) * (dim - len(squeezed_dims))
+            (slice(None),) * (dim - num_squeezed_dims)
            + (utils.ds(block_offset, slice_shape[dim]),),
        )
      stride = 1
@@ -508,9 +586,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         else contextlib.nullcontext
     )
 
-    rank = len(slice_shape)
-    if rank > 5:  # TODO: apaszke - Implement stride compression
-      raise ValueError("Async copies only support striding up to 5 dimensions")
     if max(slice_shape) > 256:
       raise ValueError(
           "Async copies only support copying <=256 elements along each"

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 12 additions & 22 deletions
@@ -61,25 +61,14 @@ def attention(q, k, v, config: TuningConfig):
     raise ValueError(f"{head_dim=} must be divisible by 64")
   if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
     raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
-  # Squash batch and sequence dimensions.
-  # This is required because CUDA grid/TMA descriptors have a limited number of
-  # slice dimensions.
-  # TODO(apaszke): Implement slice squashing for TMAs.
-  q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim))
-  k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim))
-  v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim))
 
   max_concurrent_steps = min(
       config.max_concurrent_steps, kv_seq_len // config.block_kv
   )
   block_q, block_kv = config.block_q, config.block_kv
-  num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
-  if rem:
-    raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
 
   def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
-    bidx = lax.div(lax.axis_index("bq"), num_q_tiles)
-    qidx = lax.rem(lax.axis_index("bq"), num_q_tiles)
+    batch = lax.axis_index("batch")
     smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
@@ -93,11 +82,11 @@ def perform_schedule_barrier():
     def _compute_wg():
       plgpu.set_max_registers(232, action="increase")
       qo_smem = qo_smem2.at[wg_idx]
-      q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len
+      q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
       q_head = lax.axis_index("heads")
 
       plgpu.copy_gmem_to_smem(
-          q_ref.at[pl.ds(q_seq_base, block_q), q_head],
+          q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
           qo_smem,
           q_barriers.at[wg_idx],
       )
@@ -167,24 +156,22 @@ def _wait():
       qo_smem[...] = acc.astype(dtype)
       plgpu.commit_smem()
       plgpu.copy_smem_to_gmem(
-          qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head],
+          qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
       )
       plgpu.wait_smem_to_gmem(0)
     @pl.when(wg_idx == 2)
     def _memory_wg():
       plgpu.set_max_registers(40, action="decrease")
       kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
       for i in range(max_concurrent_steps):
-        start = i * block_kv + bidx * kv_seq_len
-        s = (pl.ds(start, block_kv), kv_head)
+        s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
         plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i])
 
     def kv_loop(kv_step, _):
       tma_step = kv_step + max_concurrent_steps
       tma_slot = lax.rem(kv_step, max_concurrent_steps)
-      start = tma_step * block_kv + bidx * kv_seq_len
-      s = (pl.ds(start, block_kv), kv_head)
+      s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head)
       plgpu.barrier_wait(k_consumed_barrier)
       plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
       plgpu.barrier_wait(v_consumed_barrier)
@@ -199,10 +186,13 @@ def kv_epilogue(i, _):
   def run(refs):
     q_ref, k_ref, v_ref, out_ref = refs
 
+    num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
+    if rem:
+      raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
     mesh = plgpu.GPUMesh(
-        grid=(batch_size * num_q_tiles, num_q_heads),
+        grid=(batch_size, num_q_tiles, num_q_heads),
         num_threads=3,
-        axis_names=("bq", "heads", "wg"),
+        axis_names=("batch", "q_seq", "heads", "wg"),
         approx_math=True,
     )
     @pl.core_map(mesh)
@@ -236,7 +226,7 @@ def _kernel_entry():
   )
 
   _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
-  return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim])
+  return out
 
 
 @jax.jit
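The kernel change above swaps the flattened "bq" axis for separate "batch" and "q_seq" grid axes. A small plain-Python check (sizes made up) that the two formulations enumerate the same (batch, q_tile) pairs:

# The old kernel recovered (bidx, qidx) from one flat axis via div/rem;
# the new one reads them from two grid axes directly.
batch_size, num_q_tiles = 4, 6
old_pairs = [(bq // num_q_tiles, bq % num_q_tiles)
             for bq in range(batch_size * num_q_tiles)]
new_pairs = [(b, q) for b in range(batch_size) for q in range(num_q_tiles)]
assert old_pairs == new_pairs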

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 1 deletion
@@ -1240,7 +1240,7 @@ def run_kernel(shape):
       x = np.arange(np.prod(shape)).reshape(shape)
       _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
 
-    with self.assertRaisesRegex(ValueError, "only support striding up to 5"):
+    with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"):
       run_kernel([1] * 6)
 
     with self.assertRaisesRegex(
