Skip to content

Commit b3f280b

Browse files
authored
Add FP8 dtype support (E4M3FN and E5M2) (#15)
1 parent 4718346 commit b3f280b

File tree

4 files changed: 56 additions and 5 deletions

lib/safetensors.ex

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ defmodule Safetensors do
3131
{:f, 64} => "F64",
3232
{:f, 32} => "F32",
3333
{:f, 16} => "F16",
34+
{:f, 8} => "F8_E5M2",
35+
{:f8_e4m3fn, 8} => "F8_E4M3",
3436
{:s, 64} => "I64",
3537
{:s, 32} => "I32",
3638
{:s, 16} => "I16",
@@ -41,7 +43,7 @@ defmodule Safetensors do
4143
{:u, 8} => "U8"
4244
}
4345

44-
@dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k}
46+
@dtype_to_type for({k, v} <- @type_to_dtype, into: %{}, do: {v, k})
4547

4648
@doc """
4749
Writes a map of tensors to a file.
@@ -94,8 +96,8 @@ defmodule Safetensors do
9496
end
9597

9698
defp tensor_byte_size(tensor) do
97-
{_, elem_size} = Nx.type(tensor)
98-
elem_byte_size = div(elem_size, 8)
99+
{_, size} = Nx.type(tensor)
100+
elem_byte_size = div(size, 8)
99101
Nx.size(tensor) * elem_byte_size
100102
end
101103

mix.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ defmodule Safetensors.MixProject do
2525
defp deps do
2626
[
2727
{:jason, "~> 1.4"},
28-
{:nx, "~> 0.5"},
28+
# TODO: Switch to released version once Nx with fp8 support is published
29+
{:nx, github: "elixir-nx/nx", sparse: "nx", branch: "main"},
2930
{:ex_doc, "~> 0.37", only: :dev, runtime: false}
3031
]
3132
end

mix.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
"makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"},
88
"makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"},
99
"nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"},
10-
"nx": {:hex, :nx, "0.9.2", "17563029c01bf749aad3c31234326d7665abd0acc33ee2acbe531a4759f29a8a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "914d74741617d8103de8ab1f8c880353e555263e1c397b8a1109f79a3716557f"},
10+
"nx": {:git, "https://github.com/elixir-nx/nx.git", "04fe0ecf30cc20494f034f29fa3c07a3db7dd8c3", [sparse: "nx", branch: "main"]},
1111
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
1212
}

test/safetensors_test.exs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,52 @@ defmodule SafetensorsTest do
7878

7979
assert Safetensors.load!(serialized) == %{"test1" => Nx.tensor([[0, 0], [0, 0]], type: :s32)}
8080
end
81+
82+
@tag :tmp_dir
83+
test "write f8_e4m3fn", %{tmp_dir: tmp_dir} do
84+
path = Path.join(tmp_dir, "safetensor")
85+
86+
data = %{test: Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: :f8_e4m3fn)}
87+
Safetensors.write!(path, data)
88+
89+
assert File.read!(path) ==
90+
~s(?\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F8_E4M3","shape":[2,2],"data_offsets":[0,4]}}\x38\x40\x44\x48)
91+
end
92+
93+
@tag :tmp_dir
94+
test "read f8_e4m3fn", %{tmp_dir: tmp_dir} do
95+
path = Path.join(tmp_dir, "safetensor")
96+
97+
File.write!(
98+
path,
99+
~s(?\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F8_E4M3","shape":[2,2],"data_offsets":[0,4]}}\x38\x40\x44\x48)
100+
)
101+
102+
assert Safetensors.read!(path) == %{
103+
"test" => Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: :f8_e4m3fn)
104+
}
105+
end
106+
107+
@tag :tmp_dir
108+
test "write f8_e5m2", %{tmp_dir: tmp_dir} do
109+
path = Path.join(tmp_dir, "safetensor")
110+
111+
data = %{test: Nx.tensor([[1.0, 2.0], [4.0, 8.0]], type: :f8)}
112+
Safetensors.write!(path, data)
113+
114+
assert File.read!(path) ==
115+
~s(?\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F8_E5M2","shape":[2,2],"data_offsets":[0,4]}}\x3C\x40\x44\x48)
116+
end
117+
118+
@tag :tmp_dir
119+
test "read f8_e5m2", %{tmp_dir: tmp_dir} do
120+
path = Path.join(tmp_dir, "safetensor")
121+
122+
File.write!(
123+
path,
124+
~s(?\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F8_E5M2","shape":[2,2],"data_offsets":[0,4]}}\x3C\x40\x44\x48)
125+
)
126+
127+
assert Safetensors.read!(path) == %{"test" => Nx.tensor([[1.0, 2.0], [4.0, 8.0]], type: :f8)}
128+
end
81129
end

0 commit comments

Comments
 (0)