@@ -237,39 +237,28 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
237237
238238PrimitiveType = _xla .PrimitiveType
239239
240- bfloat16 = ml_dtypes .bfloat16
241- float4_e2m1fn = ml_dtypes .float4_e2m1fn
242- float8_e3m4 = ml_dtypes .float8_e3m4
243- float8_e4m3 = ml_dtypes .float8_e4m3
244- float8_e8m0fnu = ml_dtypes .float8_e8m0fnu
245- float8_e4m3fn = ml_dtypes .float8_e4m3fn
246- float8_e4m3b11fnuz = ml_dtypes .float8_e4m3b11fnuz
247- float8_e4m3fnuz = ml_dtypes .float8_e4m3fnuz
248- float8_e5m2 = ml_dtypes .float8_e5m2
249- float8_e5m2fnuz = ml_dtypes .float8_e5m2fnuz
250-
251240XLA_ELEMENT_TYPE_TO_DTYPE = {
252241 PrimitiveType .PRED : np .dtype ('bool' ),
253- PrimitiveType .S4 : np .dtype (' int4' ),
242+ PrimitiveType .S4 : np .dtype (ml_dtypes . int4 ),
254243 PrimitiveType .S8 : np .dtype ('int8' ),
255244 PrimitiveType .S16 : np .dtype ('int16' ),
256245 PrimitiveType .S32 : np .dtype ('int32' ),
257246 PrimitiveType .S64 : np .dtype ('int64' ),
258- PrimitiveType .U4 : np .dtype (' uint4' ),
247+ PrimitiveType .U4 : np .dtype (ml_dtypes . uint4 ),
259248 PrimitiveType .U8 : np .dtype ('uint8' ),
260249 PrimitiveType .U16 : np .dtype ('uint16' ),
261250 PrimitiveType .U32 : np .dtype ('uint32' ),
262251 PrimitiveType .U64 : np .dtype ('uint64' ),
263- PrimitiveType .F4E2M1FN : np .dtype (float4_e2m1fn ),
264- PrimitiveType .F8E3M4 : np .dtype (float8_e3m4 ),
265- PrimitiveType .F8E4M3 : np .dtype (float8_e4m3 ),
266- PrimitiveType .F8E4M3FN : np .dtype (float8_e4m3fn ),
267- PrimitiveType .F8E4M3B11FNUZ : np .dtype (float8_e4m3b11fnuz ),
268- PrimitiveType .F8E4M3FNUZ : np .dtype (float8_e4m3fnuz ),
269- PrimitiveType .F8E5M2 : np .dtype (float8_e5m2 ),
270- PrimitiveType .F8E5M2FNUZ : np .dtype (float8_e5m2fnuz ),
271- PrimitiveType .F8E8M0FNU : np .dtype (float8_e8m0fnu ),
272- PrimitiveType .BF16 : np .dtype (bfloat16 ),
252+ PrimitiveType .F4E2M1FN : np .dtype (ml_dtypes . float4_e2m1fn ),
253+ PrimitiveType .F8E3M4 : np .dtype (ml_dtypes . float8_e3m4 ),
254+ PrimitiveType .F8E4M3 : np .dtype (ml_dtypes . float8_e4m3 ),
255+ PrimitiveType .F8E4M3FN : np .dtype (ml_dtypes . float8_e4m3fn ),
256+ PrimitiveType .F8E4M3B11FNUZ : np .dtype (ml_dtypes . float8_e4m3b11fnuz ),
257+ PrimitiveType .F8E4M3FNUZ : np .dtype (ml_dtypes . float8_e4m3fnuz ),
258+ PrimitiveType .F8E5M2 : np .dtype (ml_dtypes . float8_e5m2 ),
259+ PrimitiveType .F8E5M2FNUZ : np .dtype (ml_dtypes . float8_e5m2fnuz ),
260+ PrimitiveType .F8E8M0FNU : np .dtype (ml_dtypes . float8_e8m0fnu ),
261+ PrimitiveType .BF16 : np .dtype (ml_dtypes . bfloat16 ),
273262 PrimitiveType .F16 : np .dtype ('float16' ),
274263 PrimitiveType .F32 : np .dtype ('float32' ),
275264 PrimitiveType .F64 : np .dtype ('float64' ),
0 commit comments