Skip to content

Commit f442d40

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic_gpu] Fixed FragmentedArray comparisons with literals
PiperOrigin-RevId: 698343858
1 parent c76e5fe commit f442d40

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/mosaic/gpu_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,18 +1318,21 @@ def kernel(ctx, dst, _):
13181318
operator.ne,
13191319
],
13201320
dtype=[jnp.float32, jnp.int32, jnp.uint32],
1321+
rhs_is_literal=[False, True]
13211322
)
1322-
def test_comparison(self, op, dtype, m=64, n=32):
1323+
def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32):
13231324
def kernel(ctx, dst, _):
13241325
iota = iota_tensor(m, n, dtype)
1325-
op(iota, iota + 1).store_untiled(dst)
1326+
rhs = 0 if rhs_is_literal else iota + 1
1327+
op(iota, rhs).store_untiled(dst)
13261328

13271329
out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool)
13281330
result = mgpu.as_gpu_kernel(
13291331
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
13301332
)()
13311333
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
1332-
np.testing.assert_array_equal(result, op(iota, iota + 1))
1334+
rhs = rhs = 0 if rhs_is_literal else iota + 1
1335+
np.testing.assert_array_equal(result, op(iota, rhs))
13331336

13341337
@parameterized.product(
13351338
op=[operator.and_, operator.or_, operator.xor],

0 commit comments

Comments
 (0)