Skip to content

Commit 855829e

Browse files
jburnimGoogle-ML-Automation
authored andcommitted
Add int4, uint4 to test_util.suppported_types
To increase test coverage for these types. PiperOrigin-RevId: 744777880
1 parent 5581e7d commit 855829e

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

jax/_src/test_util.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from jax._src import mesh as mesh_lib
5858
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
5959
from jax._src.interpreters import mlir
60+
from jax._src.lib import jaxlib_extension_version
6061
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
6162
from jax._src.public_test_util import ( # noqa: F401
6263
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@@ -376,10 +377,13 @@ def device_under_test():
376377

377378
def supported_dtypes():
378379
if device_under_test() == "tpu":
379-
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
380-
np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64,
380+
types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32,
381+
_dtypes.uint4, np.uint8, np.uint16, np.uint32,
382+
_dtypes.bfloat16, np.float16, np.float32, np.complex64,
381383
_dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz,
382384
_dtypes.float8_e5m2}
385+
if jaxlib_extension_version < 327:
386+
types -= {_dtypes.int4, _dtypes.uint4}
383387
elif device_under_test() == "gpu":
384388
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
385389
np.uint8, np.uint16, np.uint32, np.uint64,
@@ -389,10 +393,12 @@ def supported_dtypes():
389393
elif device_under_test() == "METAL":
390394
types = {np.int32, np.uint32, np.float32}
391395
else:
392-
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
393-
np.uint8, np.uint16, np.uint32, np.uint64,
396+
types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64,
397+
_dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64,
394398
_dtypes.bfloat16, np.float16, np.float32, np.float64,
395399
np.complex64, np.complex128}
400+
if jaxlib_extension_version < 327:
401+
types -= {_dtypes.int4, _dtypes.uint4}
396402
if not config.enable_x64.value:
397403
types -= {np.uint64, np.int64, np.float64, np.complex128}
398404
return types

jaxlib/xla/py_values.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,10 +684,12 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
684684
// float64_dt and complex128_dt which are taken care of in previous if
685685
// blocks.
686686
(*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
687+
(*p)[dtypes.np_int4.ptr()] = numpy_array_handler;
687688
(*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
688689
(*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
689690
(*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
690691
(*p)[dtypes.np_int64.ptr()] = np_int_handler;
692+
(*p)[dtypes.np_uint4.ptr()] = numpy_array_handler;
691693
(*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
692694
(*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
693695
(*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;

jaxlib/xla/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
# Just an internal arbitrary increasing number to help with backward-compatible
5252
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
53-
_version = 326
53+
_version = 327
5454

5555
# An internal increasing version number for protecting jaxlib code against
5656
# ifrt changes.

0 commit comments

Comments
 (0)