Skip to content

Commit 65d2ca6

Browse files
committed
jax.lax: raise TypeError for mismatched dtypes
1 parent 5fe8bcc commit 65d2ca6

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2209,7 +2209,7 @@ def broadcasting_sharding_rule(name, *avals):
22092209

22102210

22112211
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
2212-
require_same_dtypes=False):
2212+
require_same_dtypes=True):
22132213
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
22142214
allow_extended_dtype=allow_extended_dtype,
22152215
require_same=require_same_dtypes)

tests/lax_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3366,6 +3366,12 @@ def test_ops_do_not_accept_complex_dtypes(self, op):
33663366
with self.assertRaisesRegex(TypeError, ".*does not accept dtype complex.*"):
33673367
op(2+3j, 4+5j)
33683368

3369+
@parameterized.parameters([lax.add, lax.mul, lax.div, lax.rem, lax.lt, lax.gt,
3370+
lax.ge, lax.le, lax.eq, lax.ne])
3371+
def test_ops_error_on_mismatched_dtypes(self, op):
3372+
with self.assertRaisesRegex(TypeError, ".*requires arguments to have the same dtypes.*"):
3373+
op(0, 0.0)
3374+
33693375
def test_population_count_booleans_not_supported(self):
33703376
# https://github.com/jax-ml/jax/issues/3886
33713377
msg = "population_count does not accept dtype bool"

0 commit comments

Comments
 (0)