Skip to content

Commit b72c5dc

Browse files
committed
added reshape to Ortex.Backend
1 parent 88626bb commit b72c5dc

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

lib/ortex/backend.ex

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ defmodule Ortex.Backend do
8585
put_in(out.data, %Ortex.Backend{ref: r})
8686
end
8787

88+
@impl true
89+
def reshape(out, _tensor) do
90+
out
91+
end
92+
8893
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do
8994
defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do
9095
Inspect.Algebra.concat([

test/shape/reshape_test.exs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
defmodule Ortex.TestReshape do
2+
use ExUnit.Case
3+
4+
test "1d reshape" do
5+
t = Nx.tensor([1, 2, 3, 4])
6+
bin = t |> Nx.reshape({2, 2})
7+
8+
ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.reshape({2, 2}) |> Nx.backend_transfer()
9+
10+
assert bin == ort
11+
end
12+
13+
test "2d reshape" do
14+
shape = Nx.tensor([[0], [0], [0], [0]])
15+
t = Nx.tensor([1, 2, 3, 4])
16+
bin = t |> Nx.reshape(shape)
17+
18+
ort =
19+
t
20+
|> Nx.backend_copy(Ortex.Backend)
21+
|> Nx.reshape(shape |> Nx.backend_copy(Ortex.Backend))
22+
|> Nx.backend_transfer()
23+
24+
assert bin == ort
25+
end
26+
27+
test "scalar reshape" do
28+
shape = {1, 1, 1}
29+
t = Nx.tensor(1)
30+
bin = t |> Nx.reshape(shape)
31+
32+
ort =
33+
t
34+
|> Nx.backend_copy(Ortex.Backend)
35+
|> Nx.reshape(shape)
36+
|> Nx.backend_transfer()
37+
38+
assert bin == ort
39+
end
40+
41+
test "auto reshape" do
42+
shape = {:auto, 2}
43+
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
44+
bin = t |> Nx.reshape(shape)
45+
46+
ort =
47+
t
48+
|> Nx.backend_copy(Ortex.Backend)
49+
|> Nx.reshape(shape)
50+
|> Nx.backend_transfer()
51+
52+
assert bin == ort
53+
end
54+
end

0 commit comments

Comments
 (0)