Skip to content

Commit 09fed2f

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Remove reexports of ml_dtypes types from xla_client.py.
These should be used directly from ml_dtypes. PiperOrigin-RevId: 745256523
1 parent 62df2e8 commit 09fed2f

File tree

2 files changed

+12
-33
lines changed

2 files changed

+12
-33
lines changed

jaxlib/xla/xla_client.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -237,39 +237,28 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
237237

238238
PrimitiveType = _xla.PrimitiveType
239239

240-
bfloat16 = ml_dtypes.bfloat16
241-
float4_e2m1fn = ml_dtypes.float4_e2m1fn
242-
float8_e3m4 = ml_dtypes.float8_e3m4
243-
float8_e4m3 = ml_dtypes.float8_e4m3
244-
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
245-
float8_e4m3fn = ml_dtypes.float8_e4m3fn
246-
float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
247-
float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
248-
float8_e5m2 = ml_dtypes.float8_e5m2
249-
float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
250-
251240
XLA_ELEMENT_TYPE_TO_DTYPE = {
252241
PrimitiveType.PRED: np.dtype('bool'),
253-
PrimitiveType.S4: np.dtype('int4'),
242+
PrimitiveType.S4: np.dtype(ml_dtypes.int4),
254243
PrimitiveType.S8: np.dtype('int8'),
255244
PrimitiveType.S16: np.dtype('int16'),
256245
PrimitiveType.S32: np.dtype('int32'),
257246
PrimitiveType.S64: np.dtype('int64'),
258-
PrimitiveType.U4: np.dtype('uint4'),
247+
PrimitiveType.U4: np.dtype(ml_dtypes.uint4),
259248
PrimitiveType.U8: np.dtype('uint8'),
260249
PrimitiveType.U16: np.dtype('uint16'),
261250
PrimitiveType.U32: np.dtype('uint32'),
262251
PrimitiveType.U64: np.dtype('uint64'),
263-
PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
264-
PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
265-
PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
266-
PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn),
267-
PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz),
268-
PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz),
269-
PrimitiveType.F8E5M2: np.dtype(float8_e5m2),
270-
PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz),
271-
PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu),
272-
PrimitiveType.BF16: np.dtype(bfloat16),
252+
PrimitiveType.F4E2M1FN: np.dtype(ml_dtypes.float4_e2m1fn),
253+
PrimitiveType.F8E3M4: np.dtype(ml_dtypes.float8_e3m4),
254+
PrimitiveType.F8E4M3: np.dtype(ml_dtypes.float8_e4m3),
255+
PrimitiveType.F8E4M3FN: np.dtype(ml_dtypes.float8_e4m3fn),
256+
PrimitiveType.F8E4M3B11FNUZ: np.dtype(ml_dtypes.float8_e4m3b11fnuz),
257+
PrimitiveType.F8E4M3FNUZ: np.dtype(ml_dtypes.float8_e4m3fnuz),
258+
PrimitiveType.F8E5M2: np.dtype(ml_dtypes.float8_e5m2),
259+
PrimitiveType.F8E5M2FNUZ: np.dtype(ml_dtypes.float8_e5m2fnuz),
260+
PrimitiveType.F8E8M0FNU: np.dtype(ml_dtypes.float8_e8m0fnu),
261+
PrimitiveType.BF16: np.dtype(ml_dtypes.bfloat16),
273262
PrimitiveType.F16: np.dtype('float16'),
274263
PrimitiveType.F32: np.dtype('float32'),
275264
PrimitiveType.F64: np.dtype('float64'),

jaxlib/xla/xla_client.pyi

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,6 @@ _ifrt_version: int
6262

6363
mlir_api_version: int
6464

65-
bfloat16: type[numpy.generic]
66-
float4_e2m1fn: type[numpy.generic]
67-
float8_e3m4: type[numpy.generic]
68-
float8_e4m3: type[numpy.generic]
69-
float8_e4m3fn: type[numpy.generic]
70-
float8_e4m3b11fnuz: type[numpy.generic]
71-
float8_e4m3fnuz: type[numpy.generic]
72-
float8_e5m2: type[numpy.generic]
73-
float8_e5m2fnuz: type[numpy.generic]
74-
float8_e8m0fnu: type[numpy.generic]
7565
XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype]
7666

7767
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]

0 commit comments

Comments
 (0)