@@ -78,4 +78,52 @@ defmodule SafetensorsTest do
7878
7979 assert Safetensors . load! ( serialized ) == % { "test1" => Nx . tensor ( [ [ 0 , 0 ] , [ 0 , 0 ] ] , type: :s32 ) }
8080 end
81+
82+ @ tag :tmp_dir
83+ test "write f8_e4m3fn" , % { tmp_dir: tmp_dir } do
84+ path = Path . join ( tmp_dir , "safetensor" )
85+
86+ data = % { test: Nx . tensor ( [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] ] , type: :f8_e4m3fn ) }
87+ Safetensors . write! ( path , data )
88+
89+ assert File . read! ( path ) ==
90+ ~s( ?\x00 \x00 \x00 \x00 \x00 \x00 \x00 {"test":{"dtype":"F8_E4M3","shape":[2,2],"data_offsets":[0,4]}}\x38 \x40 \x44 \x48 )
91+ end
92+
93+ @ tag :tmp_dir
94+ test "read f8_e4m3fn" , % { tmp_dir: tmp_dir } do
95+ path = Path . join ( tmp_dir , "safetensor" )
96+
97+ File . write! (
98+ path ,
99+ ~s( ?\x00 \x00 \x00 \x00 \x00 \x00 \x00 {"test":{"dtype":"F8_E4M3","shape":[2,2],"data_offsets":[0,4]}}\x38 \x40 \x44 \x48 )
100+ )
101+
102+ assert Safetensors . read! ( path ) == % {
103+ "test" => Nx . tensor ( [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] ] , type: :f8_e4m3fn )
104+ }
105+ end
106+
107+ @ tag :tmp_dir
108+ test "write f8_e5m2" , % { tmp_dir: tmp_dir } do
109+ path = Path . join ( tmp_dir , "safetensor" )
110+
111+ data = % { test: Nx . tensor ( [ [ 1.0 , 2.0 ] , [ 4.0 , 8.0 ] ] , type: :f8 ) }
112+ Safetensors . write! ( path , data )
113+
114+ assert File . read! ( path ) ==
115+ ~s( ?\x00 \x00 \x00 \x00 \x00 \x00 \x00 {"test":{"dtype":"F8_E5M2","shape":[2,2],"data_offsets":[0,4]}}\x3C \x40 \x44 \x48 )
116+ end
117+
118+ @ tag :tmp_dir
119+ test "read f8_e5m2" , % { tmp_dir: tmp_dir } do
120+ path = Path . join ( tmp_dir , "safetensor" )
121+
122+ File . write! (
123+ path ,
124+ ~s( ?\x00 \x00 \x00 \x00 \x00 \x00 \x00 {"test":{"dtype":"F8_E5M2","shape":[2,2],"data_offsets":[0,4]}}\x3C \x40 \x44 \x48 )
125+ )
126+
127+ assert Safetensors . read! ( path ) == % { "test" => Nx . tensor ( [ [ 1.0 , 2.0 ] , [ 4.0 , 8.0 ] ] , type: :f8 ) }
128+ end
81129end
0 commit comments