Skip to content

Commit 2178ed2

Browse files
petebuGoogle-ML-Automation
authored andcommitted
[pallas] Add more test cases for Triton bitcast_convert_type lowering rule.
PiperOrigin-RevId: 698818103
1 parent 1d2dc17 commit 2178ed2

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tests/pallas/ops_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1941,9 +1941,13 @@ def kernel(x_ref, out_ref):
19411941

19421942
@parameterized.parameters(
19431943
(jnp.float16, jnp.float16), # Noop
1944-
(jnp.int16, jnp.float16),
19451944
(jnp.int16, jnp.bfloat16),
1945+
(jnp.int16, jnp.float16),
1946+
(jnp.uint16, jnp.float16),
19461947
(jnp.float32, jnp.int32),
1948+
(jnp.float32, jnp.uint32),
1949+
(jnp.uint32, jnp.int32),
1950+
(jnp.int32, jnp.uint32),
19471951
)
19481952
def test_bitcast_convert_type(self, in_dtype, out_dtype):
19491953
if jtu.test_device_matches(["tpu"]):

0 commit comments

Comments
 (0)