22use crate :: backend:: { BackendDevice , BackendStorage } ;
33use crate :: op:: { BinaryOpT , CmpOp , ReduceOp , UnaryOpT } ;
44use crate :: { DType , Error , IntDType , Layout , Result , Shape , WithDType } ;
5+ use float8:: F8E4M3 ;
56use half:: { bf16, f16} ;
67use rayon:: prelude:: * ;
78
@@ -25,6 +26,7 @@ pub enum CpuStorage {
2526 F16 ( Vec < f16 > ) ,
2627 F32 ( Vec < f32 > ) ,
2728 F64 ( Vec < f64 > ) ,
29+ F8E4M3 ( Vec < F8E4M3 > ) ,
2830}
2931
3032#[ derive( Debug , Clone ) ]
@@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> {
3638 F16 ( & ' a [ f16 ] ) ,
3739 F32 ( & ' a [ f32 ] ) ,
3840 F64 ( & ' a [ f64 ] ) ,
41+ F8E4M3 ( & ' a [ F8E4M3 ] ) ,
3942}
4043
4144#[ derive( Debug , Clone ) ]
@@ -1691,6 +1694,17 @@ impl CpuStorage {
16911694 . concat ( ) ;
16921695 Self :: F64 ( storages)
16931696 }
1697+ Self :: F8E4M3 ( _) => {
1698+ let storages = storages
1699+ . iter ( )
1700+ . map ( |s| match s {
1701+ Self :: F8E4M3 ( s) => Ok ( s. as_slice ( ) ) ,
1702+ _ => crate :: bail!( "dtype mismatch" ) ,
1703+ } )
1704+ . collect :: < Result < Vec < _ > > > ( ) ?
1705+ . concat ( ) ;
1706+ Self :: F8E4M3 ( storages)
1707+ }
16941708 } ;
16951709 Ok ( s)
16961710 }
@@ -1708,6 +1722,7 @@ impl BackendStorage for CpuStorage {
17081722 Self :: F16 ( _) => DType :: F16 ,
17091723 Self :: F32 ( _) => DType :: F32 ,
17101724 Self :: F64 ( _) => DType :: F64 ,
1725+ Self :: F8E4M3 ( _) => DType :: F8E4M3 ,
17111726 }
17121727 }
17131728
@@ -1742,6 +1757,10 @@ impl BackendStorage for CpuStorage {
17421757 let data = unary_map ( storage, layout, bf16:: from_f64) ;
17431758 Ok ( Self :: BF16 ( data) )
17441759 }
1760+ ( Self :: F8E4M3 ( storage) , DType :: BF16 ) => {
1761+ let data = unary_map ( storage, layout, |v| bf16:: from_f32 ( v. to_f32 ( ) ) ) ;
1762+ Ok ( Self :: BF16 ( data) )
1763+ }
17451764 ( Self :: U8 ( storage) , DType :: F16 ) => {
17461765 let data = unary_map ( storage, layout, |v| f16:: from_f32 ( v as f32 ) ) ;
17471766 Ok ( Self :: F16 ( data) )
@@ -1770,6 +1789,10 @@ impl BackendStorage for CpuStorage {
17701789 let data = unary_map ( storage, layout, f16:: from_f64) ;
17711790 Ok ( Self :: F16 ( data) )
17721791 }
1792+ ( Self :: F8E4M3 ( storage) , DType :: F16 ) => {
1793+ let data = unary_map ( storage, layout, |v| f16:: from_f32 ( v. to_f32 ( ) ) ) ;
1794+ Ok ( Self :: F16 ( data) )
1795+ }
17731796 ( Self :: U8 ( storage) , DType :: F32 ) => {
17741797 let data = unary_map ( storage, layout, |v| v as f32 ) ;
17751798 Ok ( Self :: F32 ( data) )
@@ -1798,6 +1821,10 @@ impl BackendStorage for CpuStorage {
17981821 let data = unary_map ( storage, layout, |v| v as f32 ) ;
17991822 Ok ( Self :: F32 ( data) )
18001823 }
1824+ ( Self :: F8E4M3 ( storage) , DType :: F32 ) => {
1825+ let data = unary_map ( storage, layout, |v| v. to_f32 ( ) ) ;
1826+ Ok ( Self :: F32 ( data) )
1827+ }
18011828 ( Self :: U8 ( storage) , DType :: U8 ) => {
18021829 let data = unary_map ( storage, layout, |v| v) ;
18031830 Ok ( Self :: U8 ( data) )
@@ -1826,6 +1853,10 @@ impl BackendStorage for CpuStorage {
18261853 let data = unary_map ( storage, layout, |v| v as u8 ) ;
18271854 Ok ( Self :: U8 ( data) )
18281855 }
1856+ ( Self :: F8E4M3 ( storage) , DType :: U8 ) => {
1857+ let data = unary_map ( storage, layout, |v| v. to_f32 ( ) as u8 ) ;
1858+ Ok ( Self :: U8 ( data) )
1859+ }
18291860 ( Self :: U8 ( storage) , DType :: U32 ) => {
18301861 let data = unary_map ( storage, layout, |v| v as u32 ) ;
18311862 Ok ( Self :: U32 ( data) )
@@ -1854,6 +1885,10 @@ impl BackendStorage for CpuStorage {
18541885 let data = unary_map ( storage, layout, |v| v as u32 ) ;
18551886 Ok ( Self :: U32 ( data) )
18561887 }
1888+ ( Self :: F8E4M3 ( storage) , DType :: U32 ) => {
1889+ let data = unary_map ( storage, layout, |v| v. to_f32 ( ) as u32 ) ;
1890+ Ok ( Self :: U32 ( data) )
1891+ }
18571892 ( Self :: U8 ( storage) , DType :: I64 ) => {
18581893 let data = unary_map ( storage, layout, |v| v as i64 ) ;
18591894 Ok ( Self :: I64 ( data) )
@@ -1882,6 +1917,10 @@ impl BackendStorage for CpuStorage {
18821917 let data = unary_map ( storage, layout, |v| v as i64 ) ;
18831918 Ok ( Self :: I64 ( data) )
18841919 }
1920+ ( Self :: F8E4M3 ( storage) , DType :: I64 ) => {
1921+ let data = unary_map ( storage, layout, |v| v. to_f32 ( ) as i64 ) ;
1922+ Ok ( Self :: I64 ( data) )
1923+ }
18851924 ( Self :: U8 ( storage) , DType :: F64 ) => {
18861925 let data = unary_map ( storage, layout, |v| v as f64 ) ;
18871926 Ok ( Self :: F64 ( data) )
@@ -1910,6 +1949,42 @@ impl BackendStorage for CpuStorage {
19101949 let data = unary_map ( storage, layout, |v| v) ;
19111950 Ok ( Self :: F64 ( data) )
19121951 }
1952+ ( Self :: F8E4M3 ( storage) , DType :: F64 ) => {
1953+ let data = unary_map ( storage, layout, |v| v. to_f64 ( ) ) ;
1954+ Ok ( Self :: F64 ( data) )
1955+ }
1956+ ( Self :: U8 ( storage) , DType :: F8E4M3 ) => {
1957+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
1958+ Ok ( Self :: F8E4M3 ( data) )
1959+ }
1960+ ( Self :: U32 ( storage) , DType :: F8E4M3 ) => {
1961+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
1962+ Ok ( Self :: F8E4M3 ( data) )
1963+ }
1964+ ( Self :: I64 ( storage) , DType :: F8E4M3 ) => {
1965+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v as f32 ) ) ;
1966+ Ok ( Self :: F8E4M3 ( data) )
1967+ }
1968+ ( Self :: BF16 ( storage) , DType :: F8E4M3 ) => {
1969+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from ( v. to_f32 ( ) ) ) ;
1970+ Ok ( Self :: F8E4M3 ( data) )
1971+ }
1972+ ( Self :: F16 ( storage) , DType :: F8E4M3 ) => {
1973+ let data = unary_map ( storage, layout, |v| F8E4M3 :: from_f32 ( v. to_f32 ( ) ) ) ;
1974+ Ok ( Self :: F8E4M3 ( data) )
1975+ }
1976+ ( Self :: F32 ( storage) , DType :: F8E4M3 ) => {
1977+ let data = unary_map ( storage, layout, F8E4M3 :: from_f32) ;
1978+ Ok ( Self :: F8E4M3 ( data) )
1979+ }
1980+ ( Self :: F64 ( storage) , DType :: F8E4M3 ) => {
1981+ let data = unary_map ( storage, layout, F8E4M3 :: from_f64) ;
1982+ Ok ( Self :: F8E4M3 ( data) )
1983+ }
1984+ ( Self :: F8E4M3 ( storage) , DType :: F8E4M3 ) => {
1985+ let data = unary_map ( storage, layout, |v| v) ;
1986+ Ok ( Self :: F8E4M3 ( data) )
1987+ }
19131988 }
19141989 }
19151990
@@ -2023,6 +2098,10 @@ impl BackendStorage for CpuStorage {
20232098 let data = unary_map ( storage, layout, |v| v. powf ( e) ) ;
20242099 Ok ( Self :: F64 ( data) )
20252100 }
2101+ Self :: F8E4M3 ( storage) => {
2102+ let data = unary_map ( storage, layout, |v| v. powf ( F8E4M3 :: from_f64 ( e) ) ) ;
2103+ Ok ( Self :: F8E4M3 ( data) )
2104+ }
20262105 Self :: U8 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U8 , "elu" ) . bt ( ) ) ,
20272106 Self :: U32 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U32 , "elu" ) . bt ( ) ) ,
20282107 Self :: I64 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: I64 , "elu" ) . bt ( ) ) ,
@@ -2048,6 +2127,10 @@ impl BackendStorage for CpuStorage {
20482127 let data = unary_map ( storage, layout, |v| elu ( v, alpha) ) ;
20492128 Ok ( Self :: F64 ( data) )
20502129 }
2130+ Self :: F8E4M3 ( storage) => {
2131+ let data = unary_map ( storage, layout, |v| elu ( v, F8E4M3 :: from_f64 ( alpha) ) ) ;
2132+ Ok ( Self :: F8E4M3 ( data) )
2133+ }
20512134 Self :: U8 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U8 , "elu" ) . bt ( ) ) ,
20522135 Self :: U32 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: U32 , "elu" ) . bt ( ) ) ,
20532136 Self :: I64 ( _) => Err ( Error :: UnsupportedDTypeForOp ( DType :: I64 , "elu" ) . bt ( ) ) ,
@@ -2092,6 +2175,15 @@ impl BackendStorage for CpuStorage {
20922175 Ok ( Self :: F64 ( data) )
20932176 }
20942177 }
2178+ Self :: F8E4M3 ( storage) => {
2179+ if B :: F8E4M3_VEC {
2180+ let data = unary_map_vec ( storage, layout, B :: f8e4m3, B :: f8e4m3_vec) ;
2181+ Ok ( Self :: F8E4M3 ( data) )
2182+ } else {
2183+ let data = unary_map ( storage, layout, B :: f8e4m3) ;
2184+ Ok ( Self :: F8E4M3 ( data) )
2185+ }
2186+ }
20952187 Self :: U8 ( storage) => {
20962188 let data = unary_map ( storage, layout, B :: u8) ;
20972189 Ok ( Self :: U8 ( data) )
@@ -2564,6 +2656,7 @@ impl BackendStorage for CpuStorage {
25642656 ( Self :: U8 ( storage) , Scalar :: U8 ( v) ) => set ( storage, l, v) ,
25652657 ( Self :: U32 ( storage) , Scalar :: U32 ( v) ) => set ( storage, l, v) ,
25662658 ( Self :: I64 ( storage) , Scalar :: I64 ( v) ) => set ( storage, l, v) ,
2659+ ( Self :: F8E4M3 ( storage) , Scalar :: F8E4M3 ( v) ) => set ( storage, l, v) ,
25672660 ( st, s) => crate :: bail!(
25682661 "const_set dtype mismatch, expected {:?} but got {:?}" ,
25692662 st. dtype( ) ,
@@ -2632,6 +2725,16 @@ impl BackendDevice for CpuDevice {
26322725 }
26332726 Ok ( CpuStorage :: F16 ( data) )
26342727 }
2728+ DType :: F8E4M3 => {
2729+ let mut data = Vec :: with_capacity ( elem_count) ;
2730+ let uniform =
2731+ rand:: distr:: Uniform :: new ( F8E4M3 :: from_f64 ( min) , F8E4M3 :: from_f64 ( max) )
2732+ . map_err ( Error :: wrap) ?;
2733+ for _i in 0 ..elem_count {
2734+ data. push ( rng. sample :: < F8E4M3 , _ > ( uniform) )
2735+ }
2736+ Ok ( CpuStorage :: F8E4M3 ( data) )
2737+ }
26352738 DType :: F32 => {
26362739 let mut data = Vec :: with_capacity ( elem_count) ;
26372740 let uniform =
@@ -2679,6 +2782,15 @@ impl BackendDevice for CpuDevice {
26792782 }
26802783 Ok ( CpuStorage :: F16 ( data) )
26812784 }
2785+ DType :: F8E4M3 => {
2786+ let mut data = Vec :: with_capacity ( elem_count) ;
2787+ let normal = rand_distr:: Normal :: new ( F8E4M3 :: from_f64 ( mean) , F8E4M3 :: from_f64 ( std) )
2788+ . map_err ( Error :: wrap) ?;
2789+ for _i in 0 ..elem_count {
2790+ data. push ( normal. sample ( & mut rng) )
2791+ }
2792+ Ok ( CpuStorage :: F8E4M3 ( data) )
2793+ }
26822794 DType :: F32 => {
26832795 let mut data = Vec :: with_capacity ( elem_count) ;
26842796 let normal =
@@ -2742,6 +2854,11 @@ impl BackendDevice for CpuDevice {
27422854 v. set_len ( elem_count) ;
27432855 CpuStorage :: F64 ( v)
27442856 }
2857+ DType :: F8E4M3 => {
2858+ let mut v = Vec :: with_capacity ( elem_count) ;
2859+ v. set_len ( elem_count) ;
2860+ CpuStorage :: F8E4M3 ( v)
2861+ }
27452862 } ;
27462863 Ok ( storage)
27472864 }
@@ -2754,6 +2871,7 @@ impl BackendDevice for CpuDevice {
27542871 DType :: I64 => CpuStorage :: I64 ( vec ! [ 0i64 ; elem_count] ) ,
27552872 DType :: BF16 => CpuStorage :: BF16 ( vec ! [ bf16:: ZERO ; elem_count] ) ,
27562873 DType :: F16 => CpuStorage :: F16 ( vec ! [ f16:: ZERO ; elem_count] ) ,
2874+ DType :: F8E4M3 => CpuStorage :: F8E4M3 ( vec ! [ F8E4M3 :: ZERO ; elem_count] ) ,
27572875 DType :: F32 => CpuStorage :: F32 ( vec ! [ 0f32 ; elem_count] ) ,
27582876 DType :: F64 => CpuStorage :: F64 ( vec ! [ 0f64 ; elem_count] ) ,
27592877 } ;
0 commit comments