Skip to content

Commit 23081e3

Browse files
authored
Merge pull request #12 from elixir-nx/tm-backend-ops
`reshape` and `squeeze` implementations on `Ortex.Backend`
2 parents 88626bb + b5fc07c commit 23081e3

File tree

7 files changed

+203
-3
lines changed

7 files changed

+203
-3
lines changed

lib/ortex/backend.ex

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ defmodule Ortex.Backend do
7171
@impl true
7272
def inspect(%T{} = tensor, inspect_opts) do
7373
limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1
74-
# IO.inspect(limit)
7574

7675
tensor
7776
|> to_binary(min(limit, Nx.size(tensor)))
@@ -82,7 +81,30 @@ defmodule Ortex.Backend do
8281
@impl true
8382
def slice(out, %T{data: %B{ref: tensor_ref}}, start_indicies, lengths, strides) do
8483
r = Ortex.Native.slice(tensor_ref, start_indicies, lengths, strides)
85-
put_in(out.data, %Ortex.Backend{ref: r})
84+
put_in(out.data, %B{ref: r})
85+
end
86+
87+
@impl true
88+
def reshape(out, %T{data: %B{ref: ref}}) do
89+
shape = Nx.shape(out) |> Tuple.to_list()
90+
put_in(out.data, %B{ref: Ortex.Native.reshape(ref, shape)})
91+
end
92+
93+
@impl true
94+
def squeeze(out, tensor, axes) do
95+
%T{shape: old_shape, names: names, data: %B{ref: ref}} = tensor
96+
{new_shape, new_names} = Nx.Shape.squeeze(old_shape, axes, names)
97+
98+
if old_shape == new_shape do
99+
%{out | data: %B{ref: ref}}
100+
else
101+
%{
102+
out
103+
| shape: new_shape,
104+
names: new_names,
105+
data: %B{ref: Ortex.Native.reshape(ref, new_shape |> Tuple.to_list())}
106+
}
107+
end
86108
end
87109

88110
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do

lib/ortex/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ defmodule Ortex.Native do
2424

2525
def slice(_tensor, _start_indicies, _lengths, _strides),
2626
do: :erlang.nif_error(:nif_not_loaded)
27+
28+
def reshape(_tensor, _shape), do: :erlang.nif_error(:nif_not_loaded)
2729
end

native/ortex/src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,26 @@ pub fn slice<'a>(
8484
strides,
8585
)))
8686
}
87+
88+
#[rustler::nif]
89+
pub fn reshape<'a>(
90+
tensor: ResourceArc<OrtexTensor>,
91+
shape: Vec<usize>,
92+
) -> NifResult<ResourceArc<OrtexTensor>> {
93+
Ok(ResourceArc::new(tensor.reshape(shape)?))
94+
}
95+
8796
rustler::init!(
8897
"Elixir.Ortex.Native",
89-
[run, init, from_binary, to_binary, show_session, slice],
98+
[
99+
run,
100+
init,
101+
from_binary,
102+
to_binary,
103+
show_session,
104+
slice,
105+
reshape
106+
],
90107
load = |env: Env, _| {
91108
rustler::resource!(OrtexModel, env);
92109
rustler::resource!(OrtexTensor, env);

native/ortex/src/tensor.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,71 @@ impl OrtexTensor {
6363
}
6464
}
6565

66+
pub fn reshape(&self, shape: Vec<usize>) -> rustler::NifResult<Self> {
67+
match self {
68+
OrtexTensor::s8(y) => {
69+
Ok(OrtexTensor::s8(y.clone().into_shape(shape).map_err(
70+
|e| rustler::Error::Term(Box::new(e.to_string())),
71+
)?))
72+
}
73+
OrtexTensor::s16(y) => {
74+
Ok(OrtexTensor::s16(y.clone().into_shape(shape).map_err(
75+
|e| rustler::Error::Term(Box::new(e.to_string())),
76+
)?))
77+
}
78+
OrtexTensor::s32(y) => {
79+
Ok(OrtexTensor::s32(y.clone().into_shape(shape).map_err(
80+
|e| rustler::Error::Term(Box::new(e.to_string())),
81+
)?))
82+
}
83+
OrtexTensor::s64(y) => {
84+
Ok(OrtexTensor::s64(y.clone().into_shape(shape).map_err(
85+
|e| rustler::Error::Term(Box::new(e.to_string())),
86+
)?))
87+
}
88+
OrtexTensor::u8(y) => {
89+
Ok(OrtexTensor::u8(y.clone().into_shape(shape).map_err(
90+
|e| rustler::Error::Term(Box::new(e.to_string())),
91+
)?))
92+
}
93+
OrtexTensor::u16(y) => {
94+
Ok(OrtexTensor::u16(y.clone().into_shape(shape).map_err(
95+
|e| rustler::Error::Term(Box::new(e.to_string())),
96+
)?))
97+
}
98+
OrtexTensor::u32(y) => {
99+
Ok(OrtexTensor::u32(y.clone().into_shape(shape).map_err(
100+
|e| rustler::Error::Term(Box::new(e.to_string())),
101+
)?))
102+
}
103+
OrtexTensor::u64(y) => {
104+
Ok(OrtexTensor::u64(y.clone().into_shape(shape).map_err(
105+
|e| rustler::Error::Term(Box::new(e.to_string())),
106+
)?))
107+
}
108+
OrtexTensor::f16(y) => {
109+
Ok(OrtexTensor::f16(y.clone().into_shape(shape).map_err(
110+
|e| rustler::Error::Term(Box::new(e.to_string())),
111+
)?))
112+
}
113+
OrtexTensor::bf16(y) => {
114+
Ok(OrtexTensor::bf16(y.clone().into_shape(shape).map_err(
115+
|e| rustler::Error::Term(Box::new(e.to_string())),
116+
)?))
117+
}
118+
OrtexTensor::f32(y) => {
119+
Ok(OrtexTensor::f32(y.clone().into_shape(shape).map_err(
120+
|e| rustler::Error::Term(Box::new(e.to_string())),
121+
)?))
122+
}
123+
OrtexTensor::f64(y) => {
124+
Ok(OrtexTensor::f64(y.clone().into_shape(shape).map_err(
125+
|e| rustler::Error::Term(Box::new(e.to_string())),
126+
)?))
127+
}
128+
}
129+
}
130+
66131
pub fn dtype(&self) -> (Atom, usize) {
67132
match self {
68133
OrtexTensor::s8(_) => (ortex_atoms::s(), 8),

test/shape/reshape_test.exs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
defmodule Ortex.TestReshape do
2+
# TODO: Fix this, it is not truly validating the reshaping on the ortex side
3+
use ExUnit.Case
4+
5+
test "1d reshape" do
6+
t = Nx.tensor([1, 2, 3, 4])
7+
bin = t |> Nx.reshape({2, 2})
8+
9+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.reshape({2, 2}) |> Nx.backend_transfer()
10+
11+
assert bin == ort
12+
end
13+
14+
test "2d reshape" do
15+
shape = Nx.tensor([[0], [0], [0], [0]])
16+
t = Nx.tensor([1, 2, 3, 4])
17+
bin = t |> Nx.reshape(shape)
18+
19+
ort =
20+
t
21+
|> Nx.backend_copy(Ortex.Backend)
22+
|> Nx.reshape(shape |> Nx.backend_copy(Ortex.Backend))
23+
|> Nx.backend_transfer()
24+
25+
assert bin == ort
26+
end
27+
28+
test "scalar reshape" do
29+
shape = {1, 1, 1}
30+
t = Nx.tensor(1)
31+
bin = t |> Nx.reshape(shape)
32+
33+
ort =
34+
t
35+
|> Nx.backend_copy(Ortex.Backend)
36+
|> Nx.reshape(shape)
37+
|> Nx.backend_transfer()
38+
39+
assert bin == ort
40+
end
41+
42+
test "auto reshape" do
43+
shape = {:auto, 2}
44+
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
45+
bin = t |> Nx.reshape(shape)
46+
47+
ort =
48+
t
49+
|> Nx.backend_copy(Ortex.Backend)
50+
|> Nx.reshape(shape)
51+
|> Nx.backend_transfer()
52+
53+
assert bin == ort
54+
end
55+
end

test/shape/squeeze_test.exs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
defmodule Ortex.TestSqueeze do
2+
use ExUnit.Case
3+
4+
test "1d squeeze" do
5+
t = Nx.tensor([[[1, 2, 3, 4]]])
6+
bin = t |> Nx.squeeze()
7+
8+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer()
9+
10+
assert bin == ort
11+
end
12+
13+
test "2d squeeze" do
14+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]])
15+
bin = t |> Nx.squeeze()
16+
17+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer()
18+
19+
assert bin == ort
20+
end
21+
22+
test "axis squeeze" do
23+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]])
24+
bin = t |> Nx.squeeze(axes: [0])
25+
26+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: [0]) |> Nx.backend_transfer()
27+
28+
assert bin == ort
29+
end
30+
31+
test "named squeeze" do
32+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]], names: [:w, :x, :y, :z])
33+
bin = t |> Nx.squeeze(axes: [:w])
34+
35+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: [:w]) |> Nx.backend_transfer()
36+
37+
assert bin == ort
38+
end
39+
end

0 commit comments

Comments
 (0)