Skip to content

Commit ea69401

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
[mgpu] Fixed off-by-one issue in pointwise argument shuffling when leading argument is splat.
Also adapted the test to catch a possible regression. The issue appeared in >2 operands. PiperOrigin-RevId: 701306731
1 parent f10d3eb commit ea69401

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
632632
continue
633633
elif not isinstance(o.layout, WGSplatFragLayout):
634634
return o._pointwise(
635-
lambda o, *args: op(*args[:i], o, *args[i:]),
635+
lambda o, this, *args: op(this, *args[:i], o, *args[i:]),
636636
self,
637637
*other[:i],
638638
*other[i + 1 :],

tests/mosaic/gpu_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,14 +1565,14 @@ def kernel(ctx, src, dst, _):
15651565
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
15661566
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
15671567
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
1568-
(pi_arr_sq + pi_arr_cube).store_untiled(dst)
1568+
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
15691569

15701570
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
15711571
inp = jnp.ones_like(out_shape) * 3.14
15721572
result = mgpu.as_gpu_kernel(
15731573
kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
15741574
)(inp)
1575-
np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32))
1575+
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
15761576

15771577

15781578
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))

0 commit comments

Comments
 (0)