Skip to content

Commit 1d2dc17

Browse files
cperivol and Google-ML-Automation
authored and committed
[mgpu] Pointwise op can handle LHS splats.
PiperOrigin-RevId: 698818035
1 parent b1b1ad6 commit 1d2dc17

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,38 @@ def to_layout(self, new_layout: FragmentedLayout):
623623
)
624624

625625
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
626+
if isinstance(self.layout, WGSplatFragLayout):
627+
# Find either the largest operand or an operand that has a
# concrete layout, and base the layout computation on that.
629+
widest_idx = None
630+
for i, o in enumerate(other):
631+
if not isinstance(o, FragmentedArray):
632+
continue
633+
elif not isinstance(o.layout, WGSplatFragLayout):
634+
widest_idx = i
635+
break
636+
elif not o.layout.can_broadcast_to(self.layout.shape):
637+
# Note: equal shapes can be broadcast to each other. Using
638+
# the negation we make sure to only consider strictly larger
639+
# shapes so that we don't end up ping ponging between equal
640+
# shapes.
641+
widest_idx = i
642+
643+
if widest_idx is not None:
644+
# We need to retain the order of arguments that the op
645+
# expects.
646+
def _op(wide_o, self_o, *args):
  # NOTE(review): the original diff sliced with `widest_idx - 1`, which
  # misplaces the wide operand whenever `op` takes three or more operands
  # (e.g. widest_idx=1 in a ternary op would produce
  # op(self, wide, other_0) instead of op(self, other_0, wide)).
  # `args` holds `other` with the widest element removed, so the wide
  # operand must be re-inserted at position `widest_idx` exactly.
  pre_wide = args[:widest_idx]
  post_wide = args[widest_idx:]
  return op(self_o, *pre_wide, wide_o, *post_wide)
650+
return other[widest_idx]._pointwise(
651+
_op,
652+
self,
653+
*other[:widest_idx],
654+
*other[widest_idx + 1:],
655+
output_is_signed=output_is_signed,
656+
)
657+
626658
other_arrs = []
627659
for o in other:
628660
if not isinstance(o, FragmentedArray):
@@ -642,7 +674,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
642674
o.registers.flat[0],
643675
shape=self.shape,
644676
layout=self.layout,
645-
is_signed=self.is_signed,
677+
is_signed=o.is_signed,
646678
)
647679
else:
648680
if self.layout != o.layout:

tests/mosaic/gpu_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,29 @@ def kernel(ctx, dst, _):
14891489
)()
14901490
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
14911491

1492+
1493+
def test_splat_binary_ops(self):
  """Checks that binary ops accept a splat-layout operand on either side.

  Exercises the `_pointwise` splat handling: `pi_arr_sq` puts the splat on
  the RHS, while `pi_arr_cube` puts the splat on the LHS — the case this
  commit enables. Both results must come out in the strided layout of the
  non-splat operand.
  """
  def kernel(ctx, src, dst, _):
    f32 = ir.F32Type.get()
    # Load the input in a concrete (strided) register layout.
    pi_arr = mgpu.FragmentedArray.load_strided(src)
    assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout)
    # Build a rank-0 splat array holding the scalar 3.14.
    pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
    pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ())
    assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout)
    # Splat on the RHS: result must adopt the strided layout.
    pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape)
    assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
    # Splat on the LHS: previously unsupported; must also come out strided.
    pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
    assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
    (pi_arr_sq + pi_arr_cube).store_untiled(dst)

  out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
  # Input is uniformly 3.14, so sq + cube is uniformly 3.14**2 + 3.14**3.
  inp = jnp.ones_like(out_shape) * 3.14
  result = mgpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
  )(inp)
  np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32))
1513+
1514+
14921515
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
14931516
def test_strided_load_store(self, in_shape):
14941517
def kernel(ctx, *args):

0 commit comments

Comments
 (0)