@@ -90,19 +90,18 @@ def type(self) -> type: ...
9090
9191
9292# fp8 support
93- # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
94- float8_e3m4 : type [np .generic ] | None = None
95- float8_e4m3 : type [np .generic ] | None = None
96- float8_e8m0fnu : type [np .generic ] | None = None
93+ float8_e3m4 : type [np .generic ] = ml_dtypes .float8_e3m4
94+ float8_e4m3 : type [np .generic ] = ml_dtypes .float8_e4m3
95+ float8_e8m0fnu : type [np .generic ] = ml_dtypes .float8_e8m0fnu
9796float8_e4m3b11fnuz : type [np .generic ] = ml_dtypes .float8_e4m3b11fnuz
9897float8_e4m3fn : type [np .generic ] = ml_dtypes .float8_e4m3fn
9998float8_e4m3fnuz : type [np .generic ] = ml_dtypes .float8_e4m3fnuz
10099float8_e5m2 : type [np .generic ] = ml_dtypes .float8_e5m2
101100float8_e5m2fnuz : type [np .generic ] = ml_dtypes .float8_e5m2fnuz
102101
103- _float8_e3m4_dtype : np .dtype | None = None
104- _float8_e4m3_dtype : np .dtype | None = None
105- _float8_e8m0fnu_dtype : np .dtype | None = None
102+ _float8_e3m4_dtype : np .dtype = np . dtype ( float8_e3m4 )
103+ _float8_e4m3_dtype : np .dtype = np . dtype ( float8_e4m3 )
104+ _float8_e8m0fnu_dtype : np .dtype = np . dtype ( float8_e8m0fnu )
106105_float8_e4m3b11fnuz_dtype : np .dtype = np .dtype (float8_e4m3b11fnuz )
107106_float8_e4m3fn_dtype : np .dtype = np .dtype (float8_e4m3fn )
108107_float8_e4m3fnuz_dtype : np .dtype = np .dtype (float8_e4m3fnuz )
@@ -111,9 +110,9 @@ def type(self) -> type: ...
111110
112111# fp4 support
113112# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
114- float4_e2m1fn : type [np .generic ] | None = None
113+ float4_e2m1fn : type [np .generic ] = ml_dtypes . float4_e2m1fn
115114
116- _float4_e2m1fn_dtype : np .dtype | None = None
115+ _float4_e2m1fn_dtype : np .dtype = np . dtype ( float4_e2m1fn )
117116
118117def supports_inf (dtype : DTypeLike ) -> bool :
119118 """Return true if the dtype supports infinity, else return False."""
@@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
127126_bfloat16_dtype : np .dtype = np .dtype (bfloat16 )
128127
129128_custom_float_scalar_types = [
129+ float4_e2m1fn ,
130+ float8_e3m4 ,
131+ float8_e4m3 ,
132+ float8_e8m0fnu ,
130133 float8_e4m3b11fnuz ,
131134 float8_e4m3fn ,
132135 float8_e4m3fnuz ,
@@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
135138 bfloat16 ,
136139]
137140_custom_float_dtypes = [
141+ _float4_e2m1fn_dtype ,
142+ _float8_e3m4_dtype ,
143+ _float8_e4m3_dtype ,
144+ _float8_e8m0fnu_dtype ,
138145 _float8_e4m3b11fnuz_dtype ,
139146 _float8_e4m3fn_dtype ,
140147 _float8_e4m3fnuz_dtype ,
@@ -143,65 +150,38 @@ def supports_inf(dtype: DTypeLike) -> bool:
143150 _bfloat16_dtype ,
144151]
145152_float8_dtypes = [
153+ _float8_e3m4_dtype ,
154+ _float8_e4m3_dtype ,
155+ _float8_e8m0fnu_dtype ,
146156 _float8_e4m3b11fnuz_dtype ,
147157 _float8_e4m3fn_dtype ,
148158 _float8_e4m3fnuz_dtype ,
149159 _float8_e5m2_dtype ,
150160 _float8_e5m2fnuz_dtype ,
151161]
152162
153- _float4_dtypes : list [np .dtype ] = []
154-
155- # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
156- if hasattr (ml_dtypes , "float8_e4m3" ):
157- float8_e4m3 = ml_dtypes .float8_e4m3
158- _float8_e4m3_dtype = np .dtype (float8_e4m3 )
159- _custom_float_scalar_types .insert (0 , float8_e4m3 ) # type: ignore[arg-type]
160- _custom_float_dtypes .insert (0 , _float8_e4m3_dtype )
161- _float8_dtypes .insert (0 , _float8_e4m3_dtype )
162- if hasattr (ml_dtypes , "float8_e3m4" ):
163- float8_e3m4 = ml_dtypes .float8_e3m4
164- _float8_e3m4_dtype = np .dtype (float8_e3m4 )
165- _custom_float_scalar_types .insert (0 , float8_e3m4 ) # type: ignore[arg-type]
166- _custom_float_dtypes .insert (0 , _float8_e3m4_dtype )
167- _float8_dtypes .insert (0 , _float8_e3m4_dtype )
168- if hasattr (ml_dtypes , "float8_e8m0fnu" ):
169- float8_e8m0fnu = ml_dtypes .float8_e8m0fnu
170- _float8_e8m0fnu_dtype = np .dtype (float8_e8m0fnu )
171- _custom_float_scalar_types .insert (0 , float8_e8m0fnu ) # type: ignore[arg-type]
172- _custom_float_dtypes .insert (0 , _float8_e8m0fnu_dtype )
173- _float8_dtypes .insert (0 , _float8_e8m0fnu_dtype )
174- if hasattr (ml_dtypes , "float4_e2m1fn" ):
175- float4_e2m1fn = ml_dtypes .float4_e2m1fn
176- _float4_e2m1fn_dtype = np .dtype (float4_e2m1fn )
177- _custom_float_scalar_types .insert (0 , float4_e2m1fn ) # type: ignore[arg-type]
178- _custom_float_dtypes .insert (0 , _float4_e2m1fn_dtype )
179- _float4_dtypes .insert (0 , _float4_e2m1fn_dtype )
180-
181- # 2-bit integer support
182- int2 : type [np .generic ] | None = None
183- uint2 : type [np .generic ] | None = None
184-
185- _int2_dtype : np .dtype | None = None
186- _uint2_dtype : np .dtype | None = None
187-
188- _intn_dtypes = []
189-
190- # Remove the condition once the minimum ml_dtypes version required by JAX
191- # contains https://github.com/jax-ml/ml_dtypes/pull/154.
192- if hasattr (ml_dtypes , 'int2' ):
193- int2 = ml_dtypes .int2
194- uint2 = ml_dtypes .uint2
195- _int2_dtype = np .dtype (int2 )
196- _uint2_dtype = np .dtype (uint2 )
197- _intn_dtypes .extend ([_int2_dtype , _uint2_dtype ])
163+ _float4_dtypes : list [np .dtype ] = [
164+ _float4_e2m1fn_dtype ,
165+ ]
166+
167+ int2 : type [np .generic ] = ml_dtypes .int2
168+ uint2 : type [np .generic ] = ml_dtypes .uint2
169+
170+ _int2_dtype : np .dtype = np .dtype (int2 )
171+ _uint2_dtype : np .dtype = np .dtype (uint2 )
198172
199173# 4-bit integer support
200174int4 : type [np .generic ] = ml_dtypes .int4
201175uint4 : type [np .generic ] = ml_dtypes .uint4
202176_int4_dtype = np .dtype (int4 )
203177_uint4_dtype = np .dtype (uint4 )
204- _intn_dtypes .extend ([_int4_dtype , _uint4_dtype ])
178+
179+ _intn_dtypes = [
180+ _int2_dtype ,
181+ _uint2_dtype ,
182+ _int4_dtype ,
183+ _uint4_dtype ,
184+ ]
205185
206186# Default types.
207187bool_ = np .bool_
@@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
472452 # to the normal scalar type hierarchy.
473453 if a_sctype in _custom_float_scalar_types :
474454 return b_sctype in {a_sctype , np .floating , np .inexact , np .number , np .generic }
475- if ( int2 is not None and a_sctype == int2 ) or a_sctype == int4 :
455+ if a_sctype in [ int2 , int4 ] :
476456 return b_sctype in {a_sctype , np .signedinteger , np .integer , np .number , np .generic }
477- if ( uint2 is not None and a_sctype == uint2 ) or a_sctype == uint4 :
457+ if a_sctype in [ uint2 , uint4 ] :
478458 return b_sctype in {a_sctype , np .unsignedinteger , np .integer , np .number , np .generic }
479459
480460 # Otherwise, fall back to numpy.issubdtype
@@ -491,25 +471,22 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
491471_unsigned_types : list [JAXType ]
492472_int_types : list [JAXType ]
493473_unsigned_types = [
474+ np .dtype (uint2 ),
494475 np .dtype (uint4 ),
495476 np .dtype ('uint8' ),
496477 np .dtype ('uint16' ),
497478 np .dtype ('uint32' ),
498479 np .dtype ('uint64' ),
499480]
500481_signed_types = [
482+ np .dtype (int2 ),
501483 np .dtype (int4 ),
502484 np .dtype ('int8' ),
503485 np .dtype ('int16' ),
504486 np .dtype ('int32' ),
505487 np .dtype ('int64' ),
506488]
507489
508- if _int2_dtype is not None :
509- _signed_types .insert (0 , _int2_dtype )
510- if _uint2_dtype is not None :
511- _unsigned_types .insert (0 , _uint2_dtype )
512-
513490_int_types = _unsigned_types + _signed_types
514491
515492_float_types : list [JAXType ] = [
@@ -622,31 +599,21 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
622599 This DAG maps each type to its immediately higher type on the lattice.
623600 """
624601 b1 , = _bool_types
625- if _int2_dtype is not None :
626- assert _uint2_dtype is not None
627- _uint2 , uint4 , u1 , u2 , u4 , u8 , _int2 , int4 , i1 , i2 , i4 , i8 = _int_types
628- else :
629- uint4 , u1 , u2 , u4 , u8 , int4 , i1 , i2 , i4 , i8 = _int_types
602+ uint2 , uint4 , u1 , u2 , u4 , u8 , int2 , int4 , i1 , i2 , i4 , i8 = _int_types
630603 * f1_types , bf , f2 , f4 , f8 = _float_types
631604 c4 , c8 = _complex_types
632605 i_ , f_ , c_ = _weak_types
633606 if jax_numpy_dtype_promotion == 'standard' :
634607 out : dict [JAXType , list [JAXType ]]
635608 out = {
636609 b1 : [i_ ],
637- i_ : [u1 , uint4 , i1 , int4 ],
638- uint4 : [], u1 : [i2 , u2 ], u2 : [i4 , u4 ], u4 : [i8 , u8 ], u8 : [f_ ],
639- int4 : [], i1 : [i2 ], i2 : [i4 ], i4 : [i8 ], i8 : [f_ ],
610+ i_ : [u1 , uint2 , uint4 , i1 , int2 , int4 ],
611+ uint2 : [], uint4 : [], u1 : [i2 , u2 ], u2 : [i4 , u4 ], u4 : [i8 , u8 ], u8 : [f_ ],
612+ int2 : [], int4 : [], i1 : [i2 ], i2 : [i4 ], i4 : [i8 ], i8 : [f_ ],
640613 f_ : [* f1_types , bf , f2 , c_ ],
641614 ** {t : [] for t in f1_types }, bf : [f4 ], f2 : [f4 ], f4 : [f8 , c4 ], f8 : [c8 ],
642615 c_ : [c4 ], c4 : [c8 ], c8 : [],
643616 }
644- if _int2_dtype is not None :
645- out [i_ ].append (_int2_dtype )
646- out [_int2_dtype ] = []
647- if _uint2_dtype is not None :
648- out [i_ ].append (_uint2_dtype )
649- out [_uint2_dtype ] = []
650617 return out
651618 elif jax_numpy_dtype_promotion == 'strict' :
652619 return {
0 commit comments