Skip to content

Commit 8a9c2b3

Browse files
Add quantized int types (#1528)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent ad28ea7 commit 8a9c2b3

File tree

13 files changed

+324
-79
lines changed

13 files changed

+324
-79
lines changed

exla/c_src/exla/exla.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,17 +505,13 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
505505
return exla::nif::error(env, "Bad argument count.");
506506
}
507507

508-
ErlNifBinary bin;
509508
xla::Shape shape;
510509
exla::ExlaClient** client;
511510
int device_id;
512511

513512
if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
514513
return exla::nif::error(env, "Unable to get client.");
515514
}
516-
if (!exla::nif::get_binary(env, argv[1], &bin)) {
517-
return exla::nif::error(env, "Unable to get data.");
518-
}
519515
if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) {
520516
return exla::nif::error(env, "Unable to get shape.");
521517
}

exla/c_src/exla/exla_client.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
6666
std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };
6767

6868
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
69+
// Passing std::nullopt should work, but it fails for subbyte types,
70+
// so we build the default strides. See https://github.com/openxla/xla/issues/16795
71+
auto byte_strides = xla::ShapeUtil::ByteStrides(shape);
6972
EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
70-
binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));
73+
binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, device));
7174

7275
return std::move(buffer);
7376
}

exla/lib/exla/backend.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ defmodule EXLA.Backend do
187187

188188
@impl true
189189
def to_binary(%T{data: %B{buffer: buffer}, type: {_, size}}, limit) do
190+
# Subbyte elements are read as individual bytes
191+
size = max(size, 8)
190192
EXLA.DeviceBuffer.read(buffer, limit * div(size, 8))
191193
end
192194

exla/lib/exla/device_buffer.ex

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,21 @@ defmodule EXLA.DeviceBuffer do
1818
Places the given binary `data` on the given `device` using `client`.
1919
"""
2020
def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id)
21-
when is_integer(device_id) and is_binary(data) do
21+
when is_integer(device_id) and is_bitstring(data) do
22+
# At the moment XLA does not support allocating a packed buffer,
23+
# so we unpack subbyte elements into their own bytes
24+
data =
25+
case typespec.type do
26+
{:u, size} when size in [2, 4] ->
27+
for <<x::native-size(size) <- data>>, into: <<>>, do: <<x::native-8>>
28+
29+
{:s, size} when size in [2, 4] ->
30+
for <<x::native-signed-size(size) <- data>>, into: <<>>, do: <<x::native-signed-8>>
31+
32+
_ ->
33+
data
34+
end
35+
2236
ref =
2337
client.ref
2438
|> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id)
@@ -47,8 +61,21 @@ defmodule EXLA.DeviceBuffer do
4761
without destroying it. If `size` is negative, then it
4862
reads the whole buffer.
4963
"""
50-
def read(%DeviceBuffer{ref: ref}, size \\ -1) do
51-
EXLA.NIF.read_device_mem(ref, size) |> unwrap!()
64+
def read(%DeviceBuffer{ref: ref, typespec: typespec}, size \\ -1) do
65+
data = EXLA.NIF.read_device_mem(ref, size) |> unwrap!()
66+
67+
# At the moment XLA does not support reading a packed buffer,
68+
# so we pack the elements ourselves
69+
case typespec.type do
70+
{:u, size} when size in [2, 4] ->
71+
for <<x::native-8 <- data>>, into: <<>>, do: <<x::native-size(size)>>
72+
73+
{:s, size} when size in [2, 4] ->
74+
for <<x::native-signed-8 <- data>>, into: <<>>, do: <<x::native-signed-size(size)>>
75+
76+
_ ->
77+
data
78+
end
5279
end
5380

5481
@doc """

exla/lib/exla/typespec.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,14 @@ defmodule EXLA.Typespec do
5353
type_to_charlist = %{
5454
:token => ~c"token",
5555
{:pred, 8} => ~c"pred",
56+
{:s, 2} => ~c"s2",
57+
{:s, 4} => ~c"s4",
5658
{:s, 8} => ~c"s8",
5759
{:s, 16} => ~c"s16",
5860
{:s, 32} => ~c"s32",
5961
{:s, 64} => ~c"s64",
62+
{:u, 2} => ~c"u2",
63+
{:u, 4} => ~c"u4",
6064
{:u, 8} => ~c"u8",
6165
{:u, 16} => ~c"u16",
6266
{:u, 32} => ~c"u32",

exla/test/exla/backend_test.exs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,76 @@ defmodule EXLA.BackendTest do
197197
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
198198
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
199199
end
200+
201+
describe "quantized types" do
202+
test "s2" do
203+
tensor = Nx.s2(-1)
204+
assert <<-1::2-signed-native>> = Nx.to_binary(tensor)
205+
206+
tensor = Nx.s2([-2, -1, 1])
207+
assert tensor.type == {:s, 2}
208+
209+
assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> =
210+
Nx.to_binary(tensor)
211+
212+
assert [-2, -1, 1] = Nx.to_flat_list(tensor)
213+
assert 0 = Nx.byte_size(tensor)
214+
assert 6 = Nx.bit_size(tensor)
215+
216+
tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2])
217+
assert 1 = Nx.byte_size(tensor)
218+
assert 14 = Nx.bit_size(tensor)
219+
end
220+
221+
test "s4" do
222+
tensor = Nx.s4(-1)
223+
assert <<-1::4-signed-native>> = Nx.to_binary(tensor)
224+
225+
tensor = Nx.s4([-8, -1, 7])
226+
assert tensor.type == {:s, 4}
227+
228+
assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> =
229+
Nx.to_binary(tensor)
230+
231+
assert [-8, -1, 7] = Nx.to_flat_list(tensor)
232+
assert 1 = Nx.byte_size(tensor)
233+
assert 12 = Nx.bit_size(tensor)
234+
235+
tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8])
236+
assert 3 = Nx.byte_size(tensor)
237+
assert 28 = Nx.bit_size(tensor)
238+
end
239+
240+
test "u2" do
241+
tensor = Nx.u2(1)
242+
assert <<1::2-native>> = Nx.to_binary(tensor)
243+
244+
tensor = Nx.u2([1, 2, 3])
245+
assert tensor.type == {:u, 2}
246+
assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor)
247+
assert [1, 2, 3] = Nx.to_flat_list(tensor)
248+
assert 0 = Nx.byte_size(tensor)
249+
assert 6 = Nx.bit_size(tensor)
250+
251+
tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0])
252+
assert 1 = Nx.byte_size(tensor)
253+
assert 14 = Nx.bit_size(tensor)
254+
end
255+
256+
test "u4" do
257+
tensor = Nx.u4(1)
258+
assert <<1::4-native>> = Nx.to_binary(tensor)
259+
260+
tensor = Nx.u4([0, 7, 15])
261+
assert tensor.type == {:u, 4}
262+
assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor)
263+
assert [0, 7, 15] = Nx.to_flat_list(tensor)
264+
assert 1 = Nx.byte_size(tensor)
265+
assert 12 = Nx.bit_size(tensor)
266+
267+
tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15])
268+
assert 3 = Nx.byte_size(tensor)
269+
assert 28 = Nx.bit_size(tensor)
270+
end
271+
end
200272
end

nx/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU. Its high-level features are:
66

7-
* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u8`, `u16`, `u32`, `u64`), signed integers (`s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);
7+
* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`), signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);
88

99
* Named tensors, allowing developers to give names to each dimension, leading to more readable and less error prone codebases;
1010

nx/guides/intro-to-nx.livemd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ libraries that support those tensors. Nx has three primary capabilities:
2424
such as machine learning, simulations, curve fitting, and probabilistic models.
2525

2626
Here's more about each of those capabilities. Nx [tensors]() can hold
27-
unsigned integers (u8, u16, u32, u64),
28-
signed integers (s8, s16, s32, s64),
27+
unsigned integers (u2, u4, u8, u16, u32, u64),
28+
signed integers (s2, s4s8, s16, s32, s64),
2929
floats (f32, f64), brain floats (bf16), and complex (c64, c128).
3030
Tensors support backends implemented outside of Elixir, including Google's
3131
Accelerated Linear Algebra (XLA) and LibTorch.

nx/lib/nx.ex

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ defmodule Nx do
4949
5050
The tensor types can be one of:
5151
52-
* unsigned integers (`u8`, `u16`, `u32`, `u64`)
53-
* signed integers (`s8`, `s16`, `s32`, `s64`)
52+
* unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`)
53+
* signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`)
5454
* floats (`f8`, `f16`, `f32`, `f64`)
5555
* brain floats (`bf16`)
5656
* and complex numbers (`c64`, `c128`)
@@ -431,6 +431,7 @@ defmodule Nx do
431431

432432
import Nx.Shared
433433
import Nx.Defn.Kernel, only: [keyword!: 2]
434+
import Kernel, except: [bit_size: 1]
434435

435436
alias Nx.Tensor, as: T
436437

@@ -855,7 +856,7 @@ defmodule Nx do
855856
{dimensions, acc} = flatten_list(list, type, [], [])
856857

857858
{dimensions |> Enum.reverse() |> List.to_tuple(),
858-
acc |> Enum.reverse() |> :erlang.list_to_binary()}
859+
acc |> Enum.reverse() |> :erlang.list_to_bitstring()}
859860
end
860861

861862
defp flatten_list([], _type, dimensions, acc) do
@@ -940,7 +941,9 @@ defmodule Nx do
940941
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
941942
end
942943

943-
for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do
944+
for t <-
945+
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
946+
[:f8, :bf16, :f16, :f32, :f64] do
944947
@doc """
945948
Short-hand function for creating tensor of type `#{t}`.
946949
@@ -1971,13 +1974,13 @@ defmodule Nx do
19711974
def from_binary(binary, type, opts \\ []) when is_binary(binary) do
19721975
opts = keyword!(opts, [:backend])
19731976
{_, size} = type = Nx.Type.normalize!(type)
1974-
dim = div(bit_size(binary), size)
1977+
dim = div(Kernel.bit_size(binary), size)
19751978

19761979
if binary == "" do
19771980
raise ArgumentError, "cannot build an empty tensor"
19781981
end
19791982

1980-
if rem(bit_size(binary), size) != 0 do
1983+
if rem(Kernel.bit_size(binary), size) != 0 do
19811984
raise ArgumentError, "binary does not match the given size"
19821985
end
19831986

@@ -1990,17 +1993,26 @@ defmodule Nx do
19901993
@doc """
19911994
Returns the underlying tensor as a binary.
19921995
1993-
**Warning**: converting a tensor to a binary can
1994-
potentially be a very expensive operation, as it
1995-
may copy a GPU tensor fully to the machine memory.
1996-
19971996
It returns the in-memory binary representation of
19981997
the tensor in a row-major fashion. The binary is
19991998
in the system endianness, which has to be taken into
20001999
account if the binary is meant to be serialized to
20012000
other systems.
20022001
2003-
Note: This function cannot be used in `defn`.
2002+
This function cannot be used in `defn`.
2003+
2004+
> ### Potentially expensive operation {: .warning}
2005+
>
2006+
> Converting a tensor to a binary can potentially be a very
2007+
> expensive operation, as it may copy a GPU tensor fully to
2008+
> the machine memory.
2009+
2010+
> ### Binaries vs bitstrings {: .info}
2011+
>
2012+
> If a tensor of type u2/u4/s2/s4 is given to this function,
2013+
> this function may not return a binary (where the number of bits
2014+
> is divisible by 8) but rather a bitstring (where the number of
2015+
> bits may not be divisible by 8).
20042016
20052017
## Options
20062018
@@ -4286,6 +4298,10 @@ defmodule Nx do
42864298
Returns the byte size of the data in the tensor
42874299
computed from its shape and type.
42884300
4301+
If the tensor has s2/s4/u2/u4 types, the value
4302+
will be rounded down. Consider using `bit_size/1`
4303+
instead.
4304+
42894305
## Examples
42904306
42914307
iex> Nx.byte_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
@@ -4304,9 +4320,33 @@ defmodule Nx do
43044320
43054321
"""
43064322
@doc type: :shape
4307-
def byte_size(tensor) do
4323+
def byte_size(tensor), do: div(bit_size(tensor), 8)
4324+
4325+
@doc """
4326+
Returns the bit size of the data in the tensor
4327+
computed from its shape and type.
4328+
4329+
## Examples
4330+
4331+
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
4332+
192
4333+
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]], type: :u8))
4334+
48
4335+
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2))
4336+
12
4337+
iex> Nx.bit_size(1)
4338+
32
4339+
4340+
Vectorized tensors account for all elements
4341+
4342+
iex> Nx.bit_size(Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x))
4343+
128
4344+
4345+
"""
4346+
@doc type: :shape
4347+
def bit_size(tensor) do
43084348
%{type: {_, bit_size}} = tensor = to_tensor(tensor)
4309-
flat_size(tensor) * div(bit_size, 8)
4349+
flat_size(tensor) * bit_size
43104350
end
43114351

43124352
@doc """
@@ -15466,9 +15506,9 @@ defmodule Nx do
1546615506
defp do_numpy_to_tensor(rest, header_size) when is_binary(rest) do
1546715507
<<header::size(header_size)-binary, array::binary>> = rest
1546815508
{byte_order, {_, size} = type, shape, fortran_order?} = parse_header(header)
15469-
byte_size_of_array = div(size, 8) * Nx.size(shape)
15509+
bit_size_of_array = size * Nx.size(shape)
1547015510

15471-
<<data::size(byte_size_of_array)-binary>> = array
15511+
<<data::size(bit_size_of_array)-bitstring>> = array
1547215512

1547315513
data
1547415514
|> new_byte_order(size, byte_order)

0 commit comments

Comments
 (0)