Skip to content

Commit 38b8c26

Browse files
committed
actual reshape implementation
1 parent b72c5dc commit 38b8c26

File tree

5 files changed

+88
-5
lines changed

5 files changed

+88
-5
lines changed

lib/ortex/backend.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ defmodule Ortex.Backend do
7171
@impl true
7272
def inspect(%T{} = tensor, inspect_opts) do
7373
limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1
74-
# IO.inspect(limit)
7574

7675
tensor
7776
|> to_binary(min(limit, Nx.size(tensor)))
@@ -82,12 +81,13 @@ defmodule Ortex.Backend do
8281
@impl true
8382
def slice(out, %T{data: %B{ref: tensor_ref}}, start_indicies, lengths, strides) do
8483
r = Ortex.Native.slice(tensor_ref, start_indicies, lengths, strides)
85-
put_in(out.data, %Ortex.Backend{ref: r})
84+
put_in(out.data, %B{ref: r})
8685
end
8786

8887
@impl true
89-
def reshape(out, _tensor) do
90-
out
88+
def reshape(out, %T{data: %B{ref: ref}}) do
89+
shape = Nx.shape(out) |> Tuple.to_list()
90+
put_in(out.data, %B{ref: Ortex.Native.reshape(ref, shape)})
9191
end
9292

9393
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do

lib/ortex/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ defmodule Ortex.Native do
2424

2525
def slice(_tensor, _start_indicies, _lengths, _strides),
2626
do: :erlang.nif_error(:nif_not_loaded)
27+
28+
def reshape(_tensor, _shape), do: :erlang.nif_error(:nif_not_loaded)
2729
end

native/ortex/src/lib.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,24 @@ pub fn slice<'a>(
8484
strides,
8585
)))
8686
}
87+
#[rustler::nif]
88+
pub fn reshape<'a>(
89+
tensor: ResourceArc<OrtexTensor>,
90+
shape: Vec<usize>,
91+
) -> NifResult<ResourceArc<OrtexTensor>> {
92+
Ok(ResourceArc::new(tensor.reshape(shape)?))
93+
}
8794
rustler::init!(
8895
"Elixir.Ortex.Native",
89-
[run, init, from_binary, to_binary, show_session, slice],
96+
[
97+
run,
98+
init,
99+
from_binary,
100+
to_binary,
101+
show_session,
102+
slice,
103+
reshape
104+
],
90105
load = |env: Env, _| {
91106
rustler::resource!(OrtexModel, env);
92107
rustler::resource!(OrtexTensor, env);

native/ortex/src/tensor.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,71 @@ impl OrtexTensor {
6363
}
6464
}
6565

66+
pub fn reshape(&self, shape: Vec<usize>) -> rustler::NifResult<Self> {
67+
match self {
68+
OrtexTensor::s8(y) => {
69+
Ok(OrtexTensor::s8(y.clone().into_shape(shape).map_err(
70+
|e| rustler::Error::Term(Box::new(e.to_string())),
71+
)?))
72+
}
73+
OrtexTensor::s16(y) => {
74+
Ok(OrtexTensor::s16(y.clone().into_shape(shape).map_err(
75+
|e| rustler::Error::Term(Box::new(e.to_string())),
76+
)?))
77+
}
78+
OrtexTensor::s32(y) => {
79+
Ok(OrtexTensor::s32(y.clone().into_shape(shape).map_err(
80+
|e| rustler::Error::Term(Box::new(e.to_string())),
81+
)?))
82+
}
83+
OrtexTensor::s64(y) => {
84+
Ok(OrtexTensor::s64(y.clone().into_shape(shape).map_err(
85+
|e| rustler::Error::Term(Box::new(e.to_string())),
86+
)?))
87+
}
88+
OrtexTensor::u8(y) => {
89+
Ok(OrtexTensor::u8(y.clone().into_shape(shape).map_err(
90+
|e| rustler::Error::Term(Box::new(e.to_string())),
91+
)?))
92+
}
93+
OrtexTensor::u16(y) => {
94+
Ok(OrtexTensor::u16(y.clone().into_shape(shape).map_err(
95+
|e| rustler::Error::Term(Box::new(e.to_string())),
96+
)?))
97+
}
98+
OrtexTensor::u32(y) => {
99+
Ok(OrtexTensor::u32(y.clone().into_shape(shape).map_err(
100+
|e| rustler::Error::Term(Box::new(e.to_string())),
101+
)?))
102+
}
103+
OrtexTensor::u64(y) => {
104+
Ok(OrtexTensor::u64(y.clone().into_shape(shape).map_err(
105+
|e| rustler::Error::Term(Box::new(e.to_string())),
106+
)?))
107+
}
108+
OrtexTensor::f16(y) => {
109+
Ok(OrtexTensor::f16(y.clone().into_shape(shape).map_err(
110+
|e| rustler::Error::Term(Box::new(e.to_string())),
111+
)?))
112+
}
113+
OrtexTensor::bf16(y) => {
114+
Ok(OrtexTensor::bf16(y.clone().into_shape(shape).map_err(
115+
|e| rustler::Error::Term(Box::new(e.to_string())),
116+
)?))
117+
}
118+
OrtexTensor::f32(y) => {
119+
Ok(OrtexTensor::f32(y.clone().into_shape(shape).map_err(
120+
|e| rustler::Error::Term(Box::new(e.to_string())),
121+
)?))
122+
}
123+
OrtexTensor::f64(y) => {
124+
Ok(OrtexTensor::f64(y.clone().into_shape(shape).map_err(
125+
|e| rustler::Error::Term(Box::new(e.to_string())),
126+
)?))
127+
}
128+
}
129+
}
130+
66131
pub fn dtype(&self) -> (Atom, usize) {
67132
match self {
68133
OrtexTensor::s8(_) => (ortex_atoms::s(), 8),

test/shape/reshape_test.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
defmodule Ortex.TestReshape do
2+
# TODO: Fix this, it is not truly validating the reshaping on the ortex side
23
use ExUnit.Case
34

45
test "1d reshape" do

0 commit comments

Comments
 (0)