1717from typing import Dict
1818
1919from ml_dtypes ._ml_dtypes_ext import bfloat16
20+ from ml_dtypes ._ml_dtypes_ext import float8_e4m3
2021from ml_dtypes ._ml_dtypes_ext import float8_e4m3b11fnuz
2122from ml_dtypes ._ml_dtypes_ext import float8_e4m3fn
2223from ml_dtypes ._ml_dtypes_ext import float8_e4m3fnuz
2526import numpy as np
2627
2728_bfloat16_dtype = np .dtype (bfloat16 )
29+ _float8_e4m3_dtype = np .dtype (float8_e4m3 )
2830_float8_e4m3b11fnuz_dtype = np .dtype (float8_e4m3b11fnuz )
2931_float8_e4m3fn_dtype = np .dtype (float8_e4m3fn )
3032_float8_e4m3fnuz_dtype = np .dtype (float8_e4m3fnuz )
@@ -41,6 +43,15 @@ def __init__(self):
4143 self .smallest_subnormal = bfloat16 (smallest_subnormal )
4244
4345
46+ class _Float8E4m3MachArLike :
47+
48+ def __init__ (self ):
49+ smallest_normal = float .fromhex ("0x1p-6" )
50+ self .smallest_normal = float8_e4m3 (smallest_normal )
51+ smallest_subnormal = float .fromhex ("0x1p-9" )
52+ self .smallest_subnormal = float8_e4m3 (smallest_subnormal )
53+
54+
4455class _Float8E4m3b11fnuzMachArLike :
4556
4657 def __init__ (self ):
@@ -135,6 +146,51 @@ def float_to_str(f):
135146 # pylint: enable=protected-access
136147 return obj
137148
149+ @staticmethod
150+ def _float8_e4m3_finfo ():
151+ def float_to_str (f ):
152+ return "%6.2e" % float (f )
153+
154+ tiny = float .fromhex ("0x1p-6" ) # 1/64 min normal
155+ resolution = 0.1
156+ eps = float .fromhex ("0x1p-3" ) # 1/8
157+ epsneg = float .fromhex ("0x1p-4" ) # 1/16
158+ max_ = float .fromhex ("0x1.Ep7" ) # 240 max normal
159+
160+ obj = object .__new__ (np .finfo )
161+ obj .dtype = _float8_e4m3_dtype
162+ obj .bits = 8
163+ obj .eps = float8_e4m3 (eps )
164+ obj .epsneg = float8_e4m3 (epsneg )
165+ obj .machep = - 3
166+ obj .negep = - 4
167+ obj .max = float8_e4m3 (max_ )
168+ obj .min = float8_e4m3 (- max_ )
169+ obj .nexp = 4
170+ obj .nmant = 3
171+ obj .iexp = obj .nexp
172+ obj .maxexp = 8
173+ obj .minexp = - 6
174+ obj .precision = 1
175+ obj .resolution = float8_e4m3 (resolution )
176+ # pylint: disable=protected-access
177+ obj ._machar = _Float8E4m3MachArLike ()
178+ if not hasattr (obj , "tiny" ):
179+ obj .tiny = float8_e4m3 (tiny )
180+ if not hasattr (obj , "smallest_normal" ):
181+ obj .smallest_normal = obj ._machar .smallest_normal
182+ obj .smallest_subnormal = obj ._machar .smallest_subnormal
183+
184+ obj ._str_tiny = float_to_str (tiny )
185+ obj ._str_smallest_normal = float_to_str (tiny )
186+ obj ._str_smallest_subnormal = float_to_str (obj .smallest_subnormal )
187+ obj ._str_max = float_to_str (max_ )
188+ obj ._str_epsneg = float_to_str (epsneg )
189+ obj ._str_eps = float_to_str (eps )
190+ obj ._str_resolution = float_to_str (resolution )
191+ # pylint: enable=protected-access
192+ return obj
193+
138194 @staticmethod
139195 def _float8_e4m3b11fnuz_finfo ():
140196 def float_to_str (f ):
@@ -369,6 +425,14 @@ def __new__(cls, dtype):
369425 if _bfloat16_dtype not in cls ._finfo_cache :
370426 cls ._finfo_cache [_bfloat16_dtype ] = cls ._bfloat16_finfo ()
371427 return cls ._finfo_cache [_bfloat16_dtype ]
428+ if (
429+ isinstance (dtype , str )
430+ and dtype == "float8_e4m3"
431+ or dtype == _float8_e4m3_dtype
432+ ):
433+ if _float8_e4m3_dtype not in cls ._finfo_cache :
434+ cls ._finfo_cache [_float8_e4m3_dtype ] = cls ._float8_e4m3_finfo ()
435+ return cls ._finfo_cache [_float8_e4m3_dtype ]
372436 if (
373437 isinstance (dtype , str )
374438 and dtype == "float8_e4m3b11fnuz"
0 commit comments