Skip to content

Commit b5fc07c

Browse files
committed
added squeeze op
1 parent 38b8c26 commit b5fc07c

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

lib/ortex/backend.ex

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,23 @@ defmodule Ortex.Backend do
9090
put_in(out.data, %B{ref: Ortex.Native.reshape(ref, shape)})
9191
end
9292

93+
@impl true
94+
def squeeze(out, tensor, axes) do
95+
%T{shape: old_shape, names: names, data: %B{ref: ref}} = tensor
96+
{new_shape, new_names} = Nx.Shape.squeeze(old_shape, axes, names)
97+
98+
if old_shape == new_shape do
99+
%{out | data: %B{ref: ref}}
100+
else
101+
%{
102+
out
103+
| shape: new_shape,
104+
names: new_names,
105+
data: %B{ref: Ortex.Native.reshape(ref, new_shape |> Tuple.to_list())}
106+
}
107+
end
108+
end
109+
93110
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do
94111
defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do
95112
Inspect.Algebra.concat([

native/ortex/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ pub fn slice<'a>(
8484
strides,
8585
)))
8686
}
87+
8788
#[rustler::nif]
8889
pub fn reshape<'a>(
8990
tensor: ResourceArc<OrtexTensor>,
9091
shape: Vec<usize>,
9192
) -> NifResult<ResourceArc<OrtexTensor>> {
9293
Ok(ResourceArc::new(tensor.reshape(shape)?))
9394
}
95+
9496
rustler::init!(
9597
"Elixir.Ortex.Native",
9698
[

test/shape/squeeze_test.exs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
defmodule Ortex.TestSqueeze do
2+
use ExUnit.Case
3+
4+
test "1d squeeze" do
5+
t = Nx.tensor([[[1, 2, 3, 4]]])
6+
bin = t |> Nx.squeeze()
7+
8+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer()
9+
10+
assert bin == ort
11+
end
12+
13+
test "2d squeeze" do
14+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]])
15+
bin = t |> Nx.squeeze()
16+
17+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer()
18+
19+
assert bin == ort
20+
end
21+
22+
test "axis squeeze" do
23+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]])
24+
bin = t |> Nx.squeeze(axes: [0])
25+
26+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: [0]) |> Nx.backend_transfer()
27+
28+
assert bin == ort
29+
end
30+
31+
test "named squeeze" do
32+
t = Nx.tensor([[[[1, 2]], [[3, 4]]]], names: [:w, :x, :y, :z])
33+
bin = t |> Nx.squeeze(axes: [:w])
34+
35+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: [:w]) |> Nx.backend_transfer()
36+
37+
assert bin == ort
38+
end
39+
end

0 commit comments

Comments
 (0)