Skip to content

Commit bf692ef

Browse files
bythew3iGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Support direct cast i8 vector to mask
PiperOrigin-RevId: 707617318
1 parent bb16d5a commit bf692ef

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,6 @@ class VectorLayoutInferer {
182182
auto false_ty = dyn_cast<VectorType>(op.getFalseValue().getType());
183183
TPU_CHECK_OP(static_cast<bool>(true_ty) == static_cast<bool>(false_ty),
184184
"Only one side of arith is a vector?");
185-
if (true_ty) {
186-
TPU_CHECK_OP(true_ty.getElementTypeBitWidth() == kNativeBitwidth &&
187-
false_ty.getElementTypeBitWidth() == kNativeBitwidth,
188-
"Only 32-bit select supported");
189-
}
190185
if (inferElementwise(&any_op).failed()) {
191186
return failure();
192187
}

tests/pallas/tpu_ops_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Tests for TPU specific operations within pallas_call."""
1515

16+
import functools
1617
import sys
1718
import unittest
1819

@@ -251,6 +252,34 @@ def body(x_ref, y_ref):
251252
expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0))
252253
np.testing.assert_array_equal(result, expected)
253254

255+
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8])
256+
def test_cast_vector_to_mask(self, dtype):
257+
shape = (128, 128)
258+
bitwidth = pallas_utils.dtype_bitwidth(dtype)
259+
if (
260+
(jtu.get_tpu_version() > 5 and bitwidth < 8)
261+
or (jtu.get_tpu_version() == 5 and bitwidth not in (8, 32))
262+
or (jtu.get_tpu_version() < 5 and bitwidth < 32)
263+
):
264+
self.skipTest(
265+
f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
266+
)
267+
268+
@functools.partial(
269+
pl.pallas_call,
270+
out_shape=jax.ShapeDtypeStruct(shape, dtype),
271+
)
272+
def kernel(x_ref, mask_ref, o_ref):
273+
zeros = jnp.zeros_like(x_ref)
274+
o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)
275+
276+
mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(dtype)
277+
x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1
278+
279+
out = kernel(x, mask)
280+
expected = jnp.where(mask, x, jnp.zeros_like(x))
281+
self.assertArraysEqual(out, expected)
282+
254283

255284
class OpsInterpretTest(OpsTest):
256285
INTERPRET = True

0 commit comments

Comments
 (0)