Skip to content

Commit e5ec49e

Browse files
seanmor5josevalim
andauthored
Add FP8 support (#1507)
Co-authored-by: José Valim <[email protected]>
1 parent 13058bb commit e5ec49e

File tree

13 files changed

+146
-15
lines changed

13 files changed

+146
-15
lines changed

exla/lib/exla/mlir/value.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,7 @@ defmodule EXLA.MLIR.Value do
880880
defp type_number({:pred, 8}), do: "i1"
881881
defp type_number({:s, width}), do: "i#{width}"
882882
defp type_number({:u, width}), do: "ui#{width}"
883+
defp type_number({:f, 8}), do: "f8E5M2"
883884
defp type_number({:f, width}), do: "f#{width}"
884885
defp type_number({:bf, width}), do: "bf#{width}"
885886
defp type_number({:c, 64}), do: "complex<f32>"
@@ -926,12 +927,17 @@ defmodule EXLA.MLIR.Value do
926927
:nan -> type |> Nx.Type.nan_binary() |> native_to_big()
927928
:infinity -> type |> Nx.Type.infinity_binary() |> native_to_big()
928929
:neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big()
930+
value when size == 8 -> f8E5M2_to_big(value)
929931
value -> <<value::float-size(size)-big>>
930932
end
931933

932934
Base.encode16(data)
933935
end
934936

937+
defp f8E5M2_to_big(x) do
938+
binary_part(<<x::float-big-16>>, 0, 1)
939+
end
940+
935941
defp native_to_big(binary) do
936942
size = byte_size(binary) * 8
937943
<<value::size(size)-native>> = binary

exla/lib/exla/typespec.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ defmodule EXLA.Typespec do
6969
{:c, 128} => ~c"c128"
7070
}
7171

72+
defp type_to_charlist({:f, 8}), do: ~c"f8e5m2"
73+
defp charlist_to_type(~c"f8"), do: {:f, 8}
74+
7275
for {type, charlist} <- type_to_charlist do
7376
defp charlist_to_type(unquote(charlist)), do: unquote(type)
7477
defp type_to_charlist(unquote(type)), do: unquote(charlist)

exla/test/exla/defn/expr_test.exs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ defmodule EXLA.Defn.ExprTest do
8686
end
8787
end
8888

89+
describe "float8" do
90+
defn return_float8, do: Nx.tensor(1, type: {:f, 8})
91+
92+
test "supports float8 return types" do
93+
assert_equal(return_float8(), Nx.tensor(1, type: {:f, 8}))
94+
end
95+
end
96+
8997
describe "float16" do
9098
defn return_float, do: Nx.tensor(1, type: {:f, 16})
9199

nx/lib/nx.ex

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ defmodule Nx do
5151
5252
* unsigned integers (`u8`, `u16`, `u32`, `u64`)
5353
* signed integers (`s8`, `s16`, `s32`, `s64`)
54-
* floats (`f16`, `f32`, `f64`)
54+
* floats (`f8`, `f16`, `f32`, `f64`)
5555
* brain floats (`bf16`)
5656
* and complex numbers (`c64`, `c128`)
5757
@@ -612,6 +612,15 @@ defmodule Nx do
612612
[1.0, 2.0, 3.0]
613613
>
614614
615+
Certain backends and compilers support 8-bit floats. On the binary
616+
backend this behavior is emulated:
617+
618+
iex> Nx.tensor([1, 2, 3], type: :f8)
619+
#Nx.Tensor<
620+
f8[3]
621+
[1.0, 2.0, 3.0]
622+
>
623+
615624
In all cases, the non-finite values negative infinity (-Inf),
616625
infinity (Inf), and "not a number" (NaN) can be represented by
617626
the atoms `:neg_infinity`, `:infinity`, and `:nan` respectively:
@@ -929,7 +938,7 @@ defmodule Nx do
929938
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
930939
end
931940

932-
for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f16, :f32, :f64] do
941+
for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do
933942
@doc """
934943
Short-hand function for creating tensor of type `#{t}`.
935944

nx/lib/nx/binary_backend.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2460,8 +2460,9 @@ defmodule Nx.BinaryBackend do
24602460
"expected a number or a scalar tensor of type #{inspect(type)}, got: #{inspect(t)}"
24612461
end
24622462

2463-
defp number_to_binary(number, type),
2464-
do: match_types([type], do: <<write!(number, 0)>>)
2463+
defp number_to_binary(number, type) do
2464+
match_types([type], do: <<write!(number, 0)>>)
2465+
end
24652466

24662467
defp binary_to_number(bin, type) do
24672468
match_types [type] do

nx/lib/nx/constants.ex

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ defmodule Nx.Constants do
2222
2323
## Examples
2424
25+
iex> Nx.Constants.nan({:f, 8})
26+
#Nx.Tensor<
27+
f8
28+
NaN
29+
>
30+
2531
iex> Nx.Constants.nan({:bf, 16})
2632
#Nx.Tensor<
2733
bf16
@@ -66,6 +72,12 @@ defmodule Nx.Constants do
6672
6773
## Examples
6874
75+
iex> Nx.Constants.infinity({:f, 8})
76+
#Nx.Tensor<
77+
f8
78+
Inf
79+
>
80+
6981
iex> Nx.Constants.infinity({:bf, 16})
7082
#Nx.Tensor<
7183
bf16
@@ -110,6 +122,12 @@ defmodule Nx.Constants do
110122
111123
## Examples
112124
125+
iex> Nx.Constants.neg_infinity({:f, 8})
126+
#Nx.Tensor<
127+
f8
128+
-Inf
129+
>
130+
113131
iex> Nx.Constants.neg_infinity({:bf, 16})
114132
#Nx.Tensor<
115133
bf16
@@ -334,6 +352,12 @@ defmodule Nx.Constants do
334352
1.1754943508222875e-38
335353
>
336354
355+
iex> Nx.Constants.smallest_positive_normal(:f8)
356+
#Nx.Tensor<
357+
f8
358+
6.103515625e-5
359+
>
360+
337361
iex> Nx.Constants.smallest_positive_normal({:s, 32})
338362
** (ArgumentError) only floating types are supported, got: {:s, 32}
339363
"""
@@ -377,6 +401,12 @@ defmodule Nx.Constants do
377401
0.0078125
378402
>
379403
404+
iex> Nx.Constants.epsilon(:f8)
405+
#Nx.Tensor<
406+
f8
407+
0.25
408+
>
409+
380410
iex> Nx.Constants.epsilon({:s, 32})
381411
** (ArgumentError) only floating types are supported, got: {:s, 32}
382412
"""
@@ -423,6 +453,12 @@ defmodule Nx.Constants do
423453
3.140625
424454
>
425455
456+
iex> Nx.Constants.pi({:f, 8})
457+
#Nx.Tensor<
458+
f8
459+
3.0
460+
>
461+
426462
iex> Nx.Constants.pi({:s, 32})
427463
** (ArgumentError) only floating types are supported, got: {:s, 32}
428464
"""
@@ -469,6 +505,12 @@ defmodule Nx.Constants do
469505
2.703125
470506
>
471507
508+
iex> Nx.Constants.e({:f, 8})
509+
#Nx.Tensor<
510+
f8
511+
2.5
512+
>
513+
472514
iex> Nx.Constants.e({:s, 32})
473515
** (ArgumentError) only floating types are supported, got: {:s, 32}
474516
"""
@@ -515,6 +557,12 @@ defmodule Nx.Constants do
515557
0.57421875
516558
>
517559
560+
iex> Nx.Constants.euler_gamma({:f, 8})
561+
#Nx.Tensor<
562+
f8
563+
0.5
564+
>
565+
518566
iex> Nx.Constants.euler_gamma({:s, 32})
519567
** (ArgumentError) only floating types are supported, got: {:s, 32}
520568
"""

nx/lib/nx/random.ex

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ defmodule Nx.Random do
299299
deftransformp mantissa_shift(nbits, type) do
300300
mantissa =
301301
case type do
302+
{:f, 8} -> 2
302303
{:bf, 16} -> 7
303304
{:f, 16} -> 10
304305
{:f, 32} -> 23

nx/lib/nx/shared.ex

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,14 @@ defmodule Nx.Shared do
105105
quote do: Nx.Shared.read_bf16(unquote(var))
106106
end
107107

108+
defp read_bin_modifier(var, :f, 8) do
109+
quote do: Nx.Shared.read_f8(unquote(var))
110+
end
111+
108112
defp read_bin_modifier(var, :f, size) do
109113
quote do
110114
case unquote(var) do
115+
_ when unquote(size) == 8 -> Nx.Shared.read_f8(unquote(var))
111116
<<var::float-native-size(unquote(size))>> -> var
112117
var -> Nx.Shared.read_non_finite(var, unquote(size))
113118
end
@@ -122,14 +127,14 @@ defmodule Nx.Shared do
122127
quote do
123128
case unquote(var) do
124129
x when is_number(x) -> binary_part(<<x::float-native-32>>, 2, 2)
125-
x -> Nx.Shared.write_bf16(x)
130+
x -> Nx.Shared.write_non_finite_bf16(x)
126131
end :: binary
127132
end
128133
else
129134
quote do
130135
case unquote(var) do
131136
x when is_number(x) -> binary_part(<<x::float-native-32>>, 0, 2)
132-
x -> Nx.Shared.write_bf16(x)
137+
x -> Nx.Shared.write_non_finite_bf16(x)
133138
end :: binary
134139
end
135140
end
@@ -155,7 +160,8 @@ defmodule Nx.Shared do
155160
defp write_bin_modifier(var, :f, size) do
156161
quote do
157162
case unquote(var) do
158-
x when is_number(x) -> <<x::float-native-size(unquote(size))>>
163+
x when is_number(x) and unquote(size) != 8 -> <<x::float-native-size(unquote(size))>>
164+
x when is_number(x) -> Nx.Shared.write_finite_f8(unquote(var))
159165
x -> Nx.Shared.write_non_finite(x, unquote(size))
160166
end :: binary
161167
end
@@ -192,6 +198,22 @@ defmodule Nx.Shared do
192198
end
193199
end
194200

201+
@doc """
202+
F8 read callback.
203+
"""
204+
def read_f8(<<0xFC::8-native>>), do: :neg_infinity
205+
def read_f8(<<0x7C::8-native>>), do: :infinity
206+
def read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan
207+
208+
def read_f8(<<sign::1, exp::5, mantissa::2>>) do
209+
float = :math.pow(2, exp - 15) * (1 + mantissa / 4)
210+
211+
case sign do
212+
0 -> float
213+
_ -> -float
214+
end
215+
end
216+
195217
@doc """
196218
C64 and C128 callback.
197219
"""
@@ -217,14 +239,24 @@ defmodule Nx.Shared do
217239
@doc """
218240
BF16 write callback.
219241
"""
220-
def write_bf16(data) do
242+
def write_non_finite_bf16(data) do
221243
case data do
222244
:infinity -> unquote(Nx.Type.infinity_binary({:bf, 16}))
223245
:neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:bf, 16}))
224246
:nan -> unquote(Nx.Type.nan_binary({:bf, 16}))
225247
end
226248
end
227249

250+
if System.endianness() == :little do
251+
def write_finite_f8(x) do
252+
binary_part(<<x::float-native-16>>, 1, 1)
253+
end
254+
else
255+
def write_finite_f8(x) do
256+
binary_part(<<x::float-native-16>>, 0, 1)
257+
end
258+
end
259+
228260
@doc """
229261
Complex write callback.
230262
"""
@@ -247,6 +279,14 @@ defmodule Nx.Shared do
247279
@doc """
248280
Non-finite read callback.
249281
"""
282+
def read_non_finite(data, 8) do
283+
case data do
284+
<<0xFC::8-native>> -> :neg_infinity
285+
<<0x7C::8-native>> -> :infinity
286+
_ -> :nan
287+
end
288+
end
289+
250290
def read_non_finite(data, 16) do
251291
case data do
252292
<<0xFC00::16-native>> -> :neg_infinity
@@ -274,7 +314,7 @@ defmodule Nx.Shared do
274314
@doc """
275315
Non-finite write callback.
276316
"""
277-
for size <- [16, 32, 64] do
317+
for size <- [8, 16, 32, 64] do
278318
def write_non_finite(data, unquote(size)) do
279319
case data do
280320
:infinity -> unquote(Nx.Type.infinity_binary({:f, size}))

0 commit comments

Comments
 (0)