Skip to content

Commit bae6600

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] FragmentedArray.reduce_sum now returns a FragmentedArray
This aligns it with the `reduce` method and also makes it clear that the reduction always produces a scalar. PiperOrigin-RevId: 703494443
1 parent fac1b1a commit bae6600

File tree

3 files changed

+6
-12
lines changed

3 files changed

+6
-12
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,13 +1256,11 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
12561256
[x_aval] = ctx.avals_in
12571257
match x.layout:
12581258
case mgpu.WGStridedFragLayout():
1259-
if axes != (0,):
1260-
raise NotImplementedError("No support for axes other than 0 yet")
1259+
if set(axes) != set(range(x_aval.ndim)):
1260+
raise NotImplementedError("No support for axes yet")
12611261
scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)
12621262
with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]:
1263-
return mgpu.FragmentedArray.splat(
1264-
x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype)
1265-
)
1263+
return x.reduce_sum(scratch)
12661264
case mgpu.WGMMA_LAYOUT:
12671265
if axes != (x_aval.ndim - 1,):
12681266
raise NotImplementedError

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ def upcast_to_bf16(reg, high):
11291129
)
11301130

11311131
# NOTE: scratch can be reused immediately once this function returns.
1132-
def reduce_sum(self, scratch) -> ir.Value:
1132+
def reduce_sum(self, scratch):
11331133
if ir.FloatType.isinstance(self.mlir_dtype):
11341134
op = addf
11351135
elif ir.IntegerType.isinstance(self.mlir_dtype):
@@ -1168,7 +1168,7 @@ def reduce_sum(self, scratch) -> ir.Value:
11681168
utils.warpgroup_barrier()
11691169
result = memref.load(scratch, [zero_index])
11701170
utils.warpgroup_barrier() # Make sure everyone is done using scratch.
1171-
return result
1171+
return FragmentedArray.splat(result, (), is_signed=self.is_signed)
11721172

11731173
def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
11741174
if isinstance(op, str):

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,11 +1481,7 @@ def kernel(ctx, src, dst, scratch):
14811481
src = mgpu.FragmentedArray.load_strided(
14821482
src, is_signed=utils.is_signed(dtype)
14831483
)
1484-
acc = mgpu.FragmentedArray.splat(
1485-
src.reduce_sum(scratch),
1486-
(m,),
1487-
is_signed=src.is_signed
1488-
)
1484+
acc = src.reduce_sum(scratch).broadcast((m,))
14891485
acc.store_untiled(dst)
14901486

14911487
in_shape = jax.ShapeDtypeStruct((m, n), dtype)

0 commit comments

Comments
 (0)