Skip to content

Commit 6edc31a

Browse files
Merge pull request jax-ml#27525 from jakevdp:ml-dtypes-cleanup
PiperOrigin-RevId: 741651222
2 parents b3a2c53 + 431c2c0 commit 6edc31a

File tree

12 files changed

+104
-215
lines changed

12 files changed

+104
-215
lines changed

jax/_src/dtypes.py

Lines changed: 43 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,18 @@ def type(self) -> type: ...
9090

9191

9292
# fp8 support
93-
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
94-
float8_e3m4: type[np.generic] | None = None
95-
float8_e4m3: type[np.generic] | None = None
96-
float8_e8m0fnu: type[np.generic] | None = None
93+
float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4
94+
float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3
95+
float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu
9796
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
9897
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
9998
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
10099
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
101100
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
102101

103-
_float8_e3m4_dtype: np.dtype | None = None
104-
_float8_e4m3_dtype: np.dtype | None = None
105-
_float8_e8m0fnu_dtype: np.dtype | None = None
102+
_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4)
103+
_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3)
104+
_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu)
106105
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
107106
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
108107
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@@ -111,9 +110,9 @@ def type(self) -> type: ...
111110

112111
# fp4 support
113112
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
114-
float4_e2m1fn: type[np.generic] | None = None
113+
float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn
115114

116-
_float4_e2m1fn_dtype: np.dtype | None = None
115+
_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn)
117116

118117
def supports_inf(dtype: DTypeLike) -> bool:
119118
"""Return true if the dtype supports infinity, else return False."""
@@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
127126
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
128127

129128
_custom_float_scalar_types = [
129+
float4_e2m1fn,
130+
float8_e3m4,
131+
float8_e4m3,
132+
float8_e8m0fnu,
130133
float8_e4m3b11fnuz,
131134
float8_e4m3fn,
132135
float8_e4m3fnuz,
@@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
135138
bfloat16,
136139
]
137140
_custom_float_dtypes = [
141+
_float4_e2m1fn_dtype,
142+
_float8_e3m4_dtype,
143+
_float8_e4m3_dtype,
144+
_float8_e8m0fnu_dtype,
138145
_float8_e4m3b11fnuz_dtype,
139146
_float8_e4m3fn_dtype,
140147
_float8_e4m3fnuz_dtype,
@@ -143,65 +150,38 @@ def supports_inf(dtype: DTypeLike) -> bool:
143150
_bfloat16_dtype,
144151
]
145152
_float8_dtypes = [
153+
_float8_e3m4_dtype,
154+
_float8_e4m3_dtype,
155+
_float8_e8m0fnu_dtype,
146156
_float8_e4m3b11fnuz_dtype,
147157
_float8_e4m3fn_dtype,
148158
_float8_e4m3fnuz_dtype,
149159
_float8_e5m2_dtype,
150160
_float8_e5m2fnuz_dtype,
151161
]
152162

153-
_float4_dtypes: list[np.dtype] = []
154-
155-
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
156-
if hasattr(ml_dtypes, "float8_e4m3"):
157-
float8_e4m3 = ml_dtypes.float8_e4m3
158-
_float8_e4m3_dtype = np.dtype(float8_e4m3)
159-
_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
160-
_custom_float_dtypes.insert(0, _float8_e4m3_dtype)
161-
_float8_dtypes.insert(0, _float8_e4m3_dtype)
162-
if hasattr(ml_dtypes, "float8_e3m4"):
163-
float8_e3m4 = ml_dtypes.float8_e3m4
164-
_float8_e3m4_dtype = np.dtype(float8_e3m4)
165-
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
166-
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
167-
_float8_dtypes.insert(0, _float8_e3m4_dtype)
168-
if hasattr(ml_dtypes, "float8_e8m0fnu"):
169-
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
170-
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
171-
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
172-
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
173-
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
174-
if hasattr(ml_dtypes, "float4_e2m1fn"):
175-
float4_e2m1fn = ml_dtypes.float4_e2m1fn
176-
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
177-
_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
178-
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
179-
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)
180-
181-
# 2-bit integer support
182-
int2: type[np.generic] | None = None
183-
uint2: type[np.generic] | None = None
184-
185-
_int2_dtype: np.dtype | None = None
186-
_uint2_dtype: np.dtype | None = None
187-
188-
_intn_dtypes = []
189-
190-
# Remove the condition once the minimum ml_dtypes version required by JAX
191-
# contains https://github.com/jax-ml/ml_dtypes/pull/154.
192-
if hasattr(ml_dtypes, 'int2'):
193-
int2 = ml_dtypes.int2
194-
uint2 = ml_dtypes.uint2
195-
_int2_dtype = np.dtype(int2)
196-
_uint2_dtype = np.dtype(uint2)
197-
_intn_dtypes.extend([_int2_dtype, _uint2_dtype])
163+
_float4_dtypes: list[np.dtype] = [
164+
_float4_e2m1fn_dtype,
165+
]
166+
167+
int2: type[np.generic] = ml_dtypes.int2
168+
uint2: type[np.generic] = ml_dtypes.uint2
169+
170+
_int2_dtype: np.dtype = np.dtype(int2)
171+
_uint2_dtype: np.dtype = np.dtype(uint2)
198172

199173
# 4-bit integer support
200174
int4: type[np.generic] = ml_dtypes.int4
201175
uint4: type[np.generic] = ml_dtypes.uint4
202176
_int4_dtype = np.dtype(int4)
203177
_uint4_dtype = np.dtype(uint4)
204-
_intn_dtypes.extend([_int4_dtype, _uint4_dtype])
178+
179+
_intn_dtypes = [
180+
_int2_dtype,
181+
_uint2_dtype,
182+
_int4_dtype,
183+
_uint4_dtype,
184+
]
205185

206186
# Default types.
207187
bool_ = np.bool_
@@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
472452
# to the normal scalar type hierarchy.
473453
if a_sctype in _custom_float_scalar_types:
474454
return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic}
475-
if (int2 is not None and a_sctype == int2) or a_sctype == int4:
455+
if a_sctype in [int2, int4]:
476456
return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic}
477-
if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4:
457+
if a_sctype in [uint2, uint4]:
478458
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
479459

480460
# Otherwise, fall back to numpy.issubdtype
@@ -491,25 +471,22 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
491471
_unsigned_types: list[JAXType]
492472
_int_types: list[JAXType]
493473
_unsigned_types = [
474+
np.dtype(uint2),
494475
np.dtype(uint4),
495476
np.dtype('uint8'),
496477
np.dtype('uint16'),
497478
np.dtype('uint32'),
498479
np.dtype('uint64'),
499480
]
500481
_signed_types = [
482+
np.dtype(int2),
501483
np.dtype(int4),
502484
np.dtype('int8'),
503485
np.dtype('int16'),
504486
np.dtype('int32'),
505487
np.dtype('int64'),
506488
]
507489

508-
if _int2_dtype is not None:
509-
_signed_types.insert(0, _int2_dtype)
510-
if _uint2_dtype is not None:
511-
_unsigned_types.insert(0, _uint2_dtype)
512-
513490
_int_types = _unsigned_types + _signed_types
514491

515492
_float_types: list[JAXType] = [
@@ -622,31 +599,21 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
622599
This DAG maps each type to its immediately higher type on the lattice.
623600
"""
624601
b1, = _bool_types
625-
if _int2_dtype is not None:
626-
assert _uint2_dtype is not None
627-
_uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types
628-
else:
629-
uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types
602+
uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types
630603
*f1_types, bf, f2, f4, f8 = _float_types
631604
c4, c8 = _complex_types
632605
i_, f_, c_ = _weak_types
633606
if jax_numpy_dtype_promotion == 'standard':
634607
out: dict[JAXType, list[JAXType]]
635608
out = {
636609
b1: [i_],
637-
i_: [u1, uint4, i1, int4],
638-
uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
639-
int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
610+
i_: [u1, uint2, uint4, i1, int2, int4],
611+
uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
612+
int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
640613
f_: [*f1_types, bf, f2, c_],
641614
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
642615
c_: [c4], c4: [c8], c8: [],
643616
}
644-
if _int2_dtype is not None:
645-
out[i_].append(_int2_dtype)
646-
out[_int2_dtype] = []
647-
if _uint2_dtype is not None:
648-
out[i_].append(_uint2_dtype)
649-
out[_uint2_dtype] = []
650617
return out
651618
elif jax_numpy_dtype_promotion == 'strict':
652619
return {

jax/_src/export/serialization.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -357,16 +357,12 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
357357
dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
358358
dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
359359
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
360+
dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4,
361+
dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
362+
dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
363+
dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
360364
}
361365

362-
if dtypes._float8_e3m4_dtype is not None:
363-
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
364-
if dtypes._float8_e4m3_dtype is not None:
365-
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
366-
if dtypes._float8_e8m0fnu_dtype is not None:
367-
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
368-
if dtypes._float4_e2m1fn_dtype is not None:
369-
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
370366
_dtype_kind_to_dtype = {
371367
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
372368
}

jax/_src/interpreters/mlir.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,14 @@ def _is_ir_values(x: IrValues) -> bool:
185185
np.dtype(np.float64): ir.F64Type.get,
186186
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
187187
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
188+
np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2),
189+
np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2),
190+
np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get,
191+
np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get,
192+
np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get,
193+
np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get,
188194
}
189195

190-
191-
if dtypes.int2 is not None:
192-
assert dtypes.uint2 is not None
193-
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
194-
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
195-
196-
if dtypes.float8_e3m4 is not None:
197-
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
198-
if dtypes.float8_e4m3 is not None:
199-
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
200-
if dtypes.float8_e8m0fnu is not None:
201-
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
202-
203-
if dtypes.float4_e2m1fn is not None:
204-
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
205-
206196
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
207197
if isinstance(dtype, core.bint):
208198
# TODO Support different-size underlying dtypes to take advantage of the

jax/_src/lax/lax.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2346,13 +2346,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
23462346
np.dtype(dtypes.float8_e4m3fnuz),
23472347
np.dtype(dtypes.float8_e5m2),
23482348
np.dtype(dtypes.float8_e5m2fnuz),
2349+
np.dtype(dtypes.float8_e3m4),
2350+
np.dtype(dtypes.float8_e4m3),
2351+
np.dtype(dtypes.float8_e8m0fnu),
23492352
]
2350-
if dtypes.float8_e3m4 is not None:
2351-
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
2352-
if dtypes.float8_e4m3 is not None:
2353-
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
2354-
if dtypes.float8_e8m0fnu is not None:
2355-
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
23562353
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
23572354
raise ValueError(
23582355
f"The dot algorithm '{self}' requires both inputs to have float8 "
@@ -5602,13 +5599,9 @@ def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
56025599
def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
56035600
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
56045601
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
5605-
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
5606-
if dtypes.float8_e3m4 is not None:
5607-
fp8_dtypes += (dtypes.float8_e3m4,)
5608-
if dtypes.float8_e4m3 is not None:
5609-
fp8_dtypes += (dtypes.float8_e4m3,)
5610-
if dtypes.float8_e8m0fnu is not None:
5611-
fp8_dtypes += (dtypes.float8_e8m0fnu,)
5602+
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz,
5603+
dtypes.float8_e3m4, dtypes.float8_e4m3,
5604+
dtypes.float8_e8m0fnu)
56125605
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
56135606

56145607
# The *_ lets us reuse this for ragged_dot_general, which has group_sizes.

jax/_src/numpy/scalar_types.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
6868
return meta
6969

7070
bool_ = _make_scalar_type(np.bool_)
71-
if dtypes.uint2 is not None:
72-
uint2 = _make_scalar_type(dtypes.uint2)
71+
uint2 = _make_scalar_type(dtypes.uint2)
7372
uint4 = _make_scalar_type(dtypes.uint4)
7473
uint8 = _make_scalar_type(np.uint8)
7574
uint16 = _make_scalar_type(np.uint16)
7675
uint32 = _make_scalar_type(np.uint32)
7776
uint64 = _make_scalar_type(np.uint64)
78-
if dtypes.int2 is not None:
79-
int2 = _make_scalar_type(dtypes.int2)
77+
int2 = _make_scalar_type(dtypes.int2)
8078
int4 = _make_scalar_type(dtypes.int4)
8179
int8 = _make_scalar_type(np.int8)
8280
int16 = _make_scalar_type(np.int16)
8381
int32 = _make_scalar_type(np.int32)
8482
int64 = _make_scalar_type(np.int64)
85-
if dtypes.float8_e3m4 is not None:
86-
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
87-
if dtypes.float8_e4m3 is not None:
88-
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
89-
if dtypes.float8_e8m0fnu is not None:
90-
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
83+
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
84+
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
85+
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
86+
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
9187
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
9288
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
9389
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
9490
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
9591
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
96-
if dtypes.float4_e2m1fn is not None:
97-
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
9892
bfloat16 = _make_scalar_type(dtypes.bfloat16)
9993
float16 = _make_scalar_type(np.float16)
10094
float32 = single = _make_scalar_type(np.float32)

0 commit comments

Comments
 (0)