Skip to content

Commit 08e3330

Browse files
committed
Fix encoding of bf16
1 parent d31c33e commit 08e3330

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

exla/lib/exla/mlir/value.ex

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,13 +921,14 @@ defmodule EXLA.MLIR.Value do
921921
end
922922
end
923923

924-
defp float_hex(value, {_, size} = type) do
924+
defp float_hex(value, {mod, size} = type) do
925925
data =
926926
case value do
927927
:nan -> type |> Nx.Type.nan_binary() |> native_to_big()
928928
:infinity -> type |> Nx.Type.infinity_binary() |> native_to_big()
929929
:neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big()
930930
value when size == 8 -> f8E5M2_to_big(value)
931+
value when mod == :bf and size == 16 -> bf16_to_big(value)
931932
value -> <<value::float-size(size)-big>>
932933
end
933934

@@ -938,6 +939,10 @@ defmodule EXLA.MLIR.Value do
938939
binary_part(<<x::float-big-16>>, 0, 1)
939940
end
940941

942+
defp bf16_to_big(x) do
943+
binary_part(<<x::float-big-32>>, 0, 2)
944+
end
945+
941946
defp native_to_big(binary) do
942947
size = byte_size(binary) * 8
943948
<<value::size(size)-native>> = binary

exla/test/exla/defn/expr_test.exs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,23 @@ defmodule EXLA.Defn.ExprTest do
8686
end
8787
end
8888

89-
describe "float8" do
90-
defn return_float8, do: Nx.tensor(1, type: {:f, 8})
89+
describe "types" do
90+
defn return_f8, do: Nx.tensor(1, type: {:f, 8})
9191

92-
test "supports float8 return types" do
93-
assert_equal(return_float8(), Nx.tensor(1, type: {:f, 8}))
92+
test "f8" do
93+
assert_equal(return_f8(), Nx.tensor(1, type: {:f, 8}))
94+
end
95+
96+
defn return_f16, do: Nx.tensor(1, type: {:f, 16})
97+
98+
test "f16" do
99+
assert_equal(return_f16(), Nx.tensor(1, type: {:f, 16}))
94100
end
95-
end
96101

97-
describe "float16" do
98-
defn return_float, do: Nx.tensor(1, type: {:f, 16})
102+
defn return_bf16, do: Nx.tensor(1, type: {:bf, 16})
99103

100-
test "supports float16 return types" do
101-
assert_equal(return_float(), Nx.tensor(1, type: {:f, 16}))
104+
test "bf16" do
105+
assert_equal(return_bf16(), Nx.tensor(1, type: {:bf, 16}))
102106
end
103107
end
104108

0 commit comments

Comments
 (0)