5757from jax ._src import mesh as mesh_lib
5858from jax ._src .cloud_tpu_init import running_in_cloud_tpu_vm
5959from jax ._src .interpreters import mlir
60+ from jax ._src .lib import jaxlib_extension_version
6061from jax ._src .numpy .util import promote_dtypes , promote_dtypes_inexact
6162from jax ._src .public_test_util import ( # noqa: F401
6263 _assert_numpy_allclose , _check_dtypes_match , _default_tolerance , _dtype , check_close , check_grads ,
@@ -376,10 +377,13 @@ def device_under_test():
376377
377378def supported_dtypes ():
378379 if device_under_test () == "tpu" :
379- types = {np .bool_ , np .int8 , np .int16 , np .int32 , np .uint8 , np .uint16 ,
380- np .uint32 , _dtypes .bfloat16 , np .float16 , np .float32 , np .complex64 ,
380+ types = {np .bool_ , _dtypes .int4 , np .int8 , np .int16 , np .int32 ,
381+ _dtypes .uint4 , np .uint8 , np .uint16 , np .uint32 ,
382+ _dtypes .bfloat16 , np .float16 , np .float32 , np .complex64 ,
381383 _dtypes .float8_e4m3fn , _dtypes .float8_e4m3b11fnuz ,
382384 _dtypes .float8_e5m2 }
385+ if jaxlib_extension_version < 327 :
386+ types -= {_dtypes .int4 , _dtypes .uint4 }
383387 elif device_under_test () == "gpu" :
384388 types = {np .bool_ , np .int8 , np .int16 , np .int32 , np .int64 ,
385389 np .uint8 , np .uint16 , np .uint32 , np .uint64 ,
@@ -389,10 +393,12 @@ def supported_dtypes():
389393 elif device_under_test () == "METAL" :
390394 types = {np .int32 , np .uint32 , np .float32 }
391395 else :
392- types = {np .bool_ , np .int8 , np .int16 , np .int32 , np .int64 ,
393- np .uint8 , np .uint16 , np .uint32 , np .uint64 ,
396+ types = {np .bool_ , _dtypes . int4 , np .int8 , np .int16 , np .int32 , np .int64 ,
397+ _dtypes . uint4 , np .uint8 , np .uint16 , np .uint32 , np .uint64 ,
394398 _dtypes .bfloat16 , np .float16 , np .float32 , np .float64 ,
395399 np .complex64 , np .complex128 }
400+ if jaxlib_extension_version < 327 :
401+ types -= {_dtypes .int4 , _dtypes .uint4 }
396402 if not config .enable_x64 .value :
397403 types -= {np .uint64 , np .int64 , np .float64 , np .complex128 }
398404 return types
0 commit comments