Commit f828f2d

cperivol authored and Google-ML-Automation committed
[mgpu] Pointwise min
PiperOrigin-RevId: 700175724
1 parent 627debc commit f828f2d

File tree

3 files changed (+60 −16 lines):
  jax/experimental/mosaic/gpu/fragmented_array.py
  tests/mosaic/gpu_test.py
  tests/pallas/mosaic_gpu_test.py

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 14 additions & 4 deletions

@@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout):
         reg, self.shape, new_layout, is_signed=self.is_signed
     )
 
-  def _pointwise(self, op, *other, output_is_signed: bool | None = None):
+  def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False):
     # If our layout is a splat, then we should either dispatch to a non-splat
     # layout, or broadcast ourselves to the output shape first.
-    if isinstance(self.layout, WGSplatFragLayout):
+    if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout):
       output_shape = self.shape
       for i, o in enumerate(other):
         if not isinstance(o, FragmentedArray):
@@ -642,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
          output_shape = np.broadcast_shapes(output_shape, o.shape)
       # If we get here then we haven't found any non-splat layout.
       return self.broadcast(output_shape)._pointwise(
-          op, *other, output_is_signed=output_is_signed
+          op, *other, output_is_signed=output_is_signed, force_no_dispatch=True,
       )
 
     other_arrs = []
@@ -884,7 +884,17 @@ def max(self, other):
           arith.maxsi if self.is_signed else arith.maxui, other
       )
     else:
-      return NotImplemented
+      return NotImplementedError
+
+  def min(self, other):
+    if ir.FloatType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.minimumf, other)
+    elif ir.IntegerType.isinstance(self.mlir_dtype):
+      return self._pointwise(
+          arith.minsi if self.is_signed else arith.minui, other
+      )
+    else:
+      return NotImplementedError
 
   def exp(self, *, approx: bool = False):
     if not ir.FloatType.isinstance(self.mlir_dtype):
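
For illustration, here is a toy, NumPy-only sketch of the splat-dispatch logic that the new force_no_dispatch flag guards: a splat operand defers the operation to the first non-splat operand, and when every operand is a splat the array broadcasts itself and retries with force_no_dispatch=True so the retry cannot dispatch again and recurse forever. The ToyArray class below is hypothetical and only models this control flow; it is not the FragmentedArray implementation.

import numpy as np
from dataclasses import dataclass


@dataclass
class ToyArray:
  """Hypothetical stand-in for FragmentedArray, modeling only the dispatch."""
  data: np.ndarray
  is_splat: bool  # analogue of isinstance(self.layout, WGSplatFragLayout)

  def broadcast(self, shape):
    return ToyArray(np.broadcast_to(self.data, shape), is_splat=True)

  def _pointwise(self, op, other, force_no_dispatch=False):
    if not force_no_dispatch and self.is_splat:
      if not other.is_splat:
        # Dispatch to the non-splat operand, restoring the argument order.
        return other._pointwise(lambda o, s: op(s, o), self)
      # Every operand is a splat: broadcast and retry, but forbid another
      # dispatch so two splat operands cannot bounce between each other forever.
      out_shape = np.broadcast_shapes(self.data.shape, other.data.shape)
      return self.broadcast(out_shape)._pointwise(op, other, force_no_dispatch=True)
    return ToyArray(op(self.data, other.data), is_splat=False)

  def min(self, other):
    return self._pointwise(np.minimum, other)


a = ToyArray(np.array(1.0, np.float32), is_splat=True)        # splat "scalar"
b = ToyArray(np.full((4, 4), 2.0, np.float32), is_splat=True)
print(a.min(b).data.shape)  # (4, 4); terminates thanks to force_no_dispatch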

tests/mosaic/gpu_test.py

Lines changed: 27 additions & 0 deletions

@@ -1256,6 +1256,7 @@ class FragmentedArrayTest(TestCase):
           operator.add,
           operator.mul,
           operator.sub,
+          (lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum),
           (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
       ),
       dtype=[jnp.float32, jnp.int32, jnp.uint32],
@@ -1285,6 +1286,32 @@ def kernel(ctx, dst, _):
     ref_rhs = scalar_rhs or ref_x
     np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs))
 
+  def test_minimum_np_compatibility(self):
+    one = np.ones((128, 128)).astype(np.float32)
+    negz = one * -0.
+    posz = one * 0.
+    nan = one * np.nan
+    expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one)
+    assert np.all(expectation), expectation
+
+    def kernel(ctx, dst, _):
+      f32 = ir.F32Type.get()
+      splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128))
+      negz = splat(-0.)
+      posz = splat(0.)
+      nan = splat(np.nan)
+      one = splat(1.)
+      res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one)
+      i8 = ir.IntegerType.get_signless(8)
+      res.astype(i8, is_signed=False).store_untiled(dst)
+
+    out_shape = jax.ShapeDtypeStruct((128, 128), np.int8)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
+    )()
+    # astype() uses extsi so i1=True becomes -1
+    np.testing.assert_array_equal(result == -1, expectation)
+
   @parameterized.product(
       op=[operator.truediv, operator.floordiv, operator.mod],
       dtype=[jnp.float32, jnp.int32, jnp.uint32],
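
The new test above pins down NumPy-compatible NaN handling. As a quick reference (plain NumPy, runnable anywhere): np.minimum propagates a NaN from either operand, and MLIR's arith.minimumf, which FragmentedArray.min lowers to for float inputs, shares that NaN-propagating behavior (unlike arith.minnumf, which returns the non-NaN operand).

import numpy as np

# np.minimum propagates NaN from either side, which is the behavior the
# kernel-side checks `one.min(nan) != one` and `nan.min(one) != one` rely on.
print(np.minimum(np.nan, 1.0))  # nan
print(np.minimum(1.0, np.nan))  # nan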

tests/pallas/mosaic_gpu_test.py

Lines changed: 19 additions & 12 deletions

@@ -83,6 +83,25 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol)
 
+  @parameterized.named_parameters(
+      ("add", lambda x, y: x + y),
+      ("mul", lambda x, y: x * y),
+      ("div", lambda x, y: x / y),
+      ("min", lambda x, y: jnp.minimum(x, y)),
+      ("max", lambda x, y: jnp.maximum(x, y)),
+  )
+  def test_binary_op(self, bop):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = bop(x_ref[...], y_ref[...])
+
+    x = jnp.arange(256).astype(jnp.float32)
+    y = x + 1
+    np.testing.assert_array_equal(kernel(x, y), bop(x, y))
+
   def test_add_first(self):
     @functools.partial(
         pl.pallas_call,
@@ -111,18 +130,6 @@ def kernel(x_ref, out_ref):
     x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
 
-  def test_add_xy(self):
-    @functools.partial(
-        pl.pallas_call,
-        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
-    )
-    def kernel(x_ref, y_ref, o_ref):
-      o_ref[...] = x_ref[...] + y_ref[...]
-
-    x = jnp.arange(256).astype(jnp.float32)
-    y = x + 1
-    np.testing.assert_array_equal(kernel(x, y), x + y)
-
   def test_add_xy_indexed(self):
     @functools.partial(
         pl.pallas_call,
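
As a standalone usage sketch of the "min" case exercised by the new test_binary_op parameterization, assuming a GPU runtime where pl.pallas_call lowers through the Mosaic GPU backend (as configured by the surrounding test class):

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def minimum_kernel(x_ref, y_ref, o_ref):
  # With the Mosaic GPU lowering, jnp.minimum reaches the new pointwise min.
  o_ref[...] = jnp.minimum(x_ref[...], y_ref[...])


x = jnp.arange(256).astype(jnp.float32)
y = 255.0 - x
np.testing.assert_array_equal(minimum_kernel(x, y), jnp.minimum(x, y))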
