Skip to content

Commit 935cde3

Browse files
authored
Merge pull request #9 from relaypro-open/develop
implemented `slice` on `Ortex.Backend`
2 parents 6770469 + 1d69995 commit 935cde3

File tree

5 files changed

+166
-3
lines changed

5 files changed

+166
-3
lines changed

lib/ortex/backend.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ defmodule Ortex.Backend do
7979
|> maybe_add_signature(tensor)
8080
end
8181

82+
@impl true
83+
def slice(out, %T{data: %B{ref: tensor_ref}}, start_indicies, lengths, strides) do
84+
r = Ortex.Native.slice(tensor_ref, start_indicies, lengths, strides)
85+
put_in(out.data, %Ortex.Backend{ref: r})
86+
end
87+
8288
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do
8389
defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do
8490
Inspect.Algebra.concat([

lib/ortex/native.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,7 @@ defmodule Ortex.Native do
2121
def from_binary(_bin, _shape, _type), do: :erlang.nif_error(:nif_not_loaded)
2222
def to_binary(_reference, _bits, _limit), do: :erlang.nif_error(:nif_not_loaded)
2323
def show_session(_model), do: :erlang.nif_error(:nif_not_loaded)
24+
25+
def slice(_tensor, _start_indicies, _lengths, _strides),
26+
do: :erlang.nif_error(:nif_not_loaded)
2427
end

native/ortex/src/lib.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,22 @@ fn to_binary<'a>(
7171
utils::to_binary(env, reference, bits, limit)
7272
}
7373

74+
#[rustler::nif]
75+
pub fn slice<'a>(
76+
tensor: ResourceArc<OrtexTensor>,
77+
start_indicies: Vec<isize>,
78+
lengths: Vec<isize>,
79+
strides: Vec<isize>,
80+
) -> NifResult<ResourceArc<OrtexTensor>> {
81+
Ok(ResourceArc::new(tensor.slice(
82+
start_indicies,
83+
lengths,
84+
strides,
85+
)))
86+
}
7487
rustler::init!(
7588
"Elixir.Ortex.Native",
76-
[run, init, from_binary, to_binary, show_session],
89+
[run, init, from_binary, to_binary, show_session, slice],
7790
load = |env: Env, _| {
7891
rustler::resource!(OrtexModel, env);
7992
rustler::resource!(OrtexTensor, env);

native/ortex/src/tensor.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3-
use ndarray::Data;
3+
use ndarray::{ArrayBase, ArrayView, Data, IxDyn};
44
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
55
use ort::OrtError;
66
use rustler::Atom;
@@ -9,7 +9,8 @@ use crate::constants::ortex_atoms;
99

1010
#[derive(Debug)]
1111
#[allow(non_camel_case_types)]
12-
/// Enum for wrapping different types to pass back to the BEAM
12+
/// Enum for wrapping different types to pass back to the BEAM since rustler can't
13+
/// pass type generics back and forth
1314
pub enum OrtexTensor {
1415
s8(Array<i8, IxDyn>),
1516
s16(Array<i16, IxDyn>),
@@ -96,6 +97,49 @@ impl OrtexTensor {
9697
};
9798
contents
9899
}
100+
101+
pub fn slice<'a>(
102+
&'a self,
103+
start_indicies: Vec<isize>,
104+
lengths: Vec<isize>,
105+
strides: Vec<isize>,
106+
) -> Self {
107+
let mut slice_specs: Vec<(isize, Option<isize>, isize)> = vec![];
108+
for ((start_index, length), stride) in start_indicies
109+
.iter()
110+
.zip(lengths.iter())
111+
.zip(strides.iter())
112+
{
113+
slice_specs.push((*start_index, Some(*length + *start_index), *stride));
114+
}
115+
match self {
116+
OrtexTensor::s8(y) => OrtexTensor::s8(slice_array(y, &slice_specs).to_owned()),
117+
OrtexTensor::s16(y) => OrtexTensor::s16(slice_array(y, &slice_specs).to_owned()),
118+
OrtexTensor::s32(y) => OrtexTensor::s32(slice_array(y, &slice_specs).to_owned()),
119+
OrtexTensor::s64(y) => OrtexTensor::s64(slice_array(y, &slice_specs).to_owned()),
120+
OrtexTensor::u8(y) => OrtexTensor::u8(slice_array(y, &slice_specs).to_owned()),
121+
OrtexTensor::u16(y) => OrtexTensor::u16(slice_array(y, &slice_specs).to_owned()),
122+
OrtexTensor::u32(y) => OrtexTensor::u32(slice_array(y, &slice_specs).to_owned()),
123+
OrtexTensor::u64(y) => OrtexTensor::u64(slice_array(y, &slice_specs).to_owned()),
124+
OrtexTensor::f16(y) => OrtexTensor::f16(slice_array(y, &slice_specs).to_owned()),
125+
OrtexTensor::bf16(y) => OrtexTensor::bf16(slice_array(y, &slice_specs).to_owned()),
126+
OrtexTensor::f32(y) => OrtexTensor::f32(slice_array(y, &slice_specs).to_owned()),
127+
OrtexTensor::f64(y) => OrtexTensor::f64(slice_array(y, &slice_specs).to_owned()),
128+
}
129+
}
130+
}
131+
132+
fn slice_array<'a, T, D>(
133+
array: &'a Array<T, D>,
134+
slice_specs: &'a Vec<(isize, Option<isize>, isize)>,
135+
) -> ArrayView<'a, T, D>
136+
where
137+
D: Dimension,
138+
{
139+
array.slice_each_axis(|ax: ndarray::AxisDescription| {
140+
let (start, end, step) = slice_specs[ax.axis.index()];
141+
ndarray::Slice { start, end, step }
142+
})
99143
}
100144

101145
fn get_bytes<'a, T>(array: &'a ArrayBase<T, IxDyn>) -> &'a [u8]

test/slice/slice_test.exs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
defmodule Ortex.TestSlice do
2+
use ExUnit.Case
3+
4+
{tensor1d, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {10})
5+
{tensor2d, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {10, 10})
6+
7+
@tensor1d tensor1d
8+
@tensor2d tensor2d
9+
10+
defp tensor_binary(tensor, dtype) do
11+
tensor |> Nx.as_type(dtype)
12+
end
13+
14+
defp tensor_ortex(tensor, dtype) do
15+
tensor
16+
|> Nx.as_type(dtype)
17+
|> Nx.backend_transfer(Ortex.Backend)
18+
end
19+
20+
test "1d slice f32" do
21+
bin = tensor_binary(@tensor1d, :f32) |> Nx.slice([0], [4])
22+
23+
ort = tensor_ortex(@tensor1d, :f32) |> Nx.slice([0], [4]) |> Nx.backend_transfer()
24+
25+
assert bin == ort
26+
end
27+
28+
test "2d slice f32" do
29+
bin = tensor_binary(@tensor2d, :f32) |> Nx.slice([0, 2], [4, 6])
30+
31+
ort =
32+
tensor_ortex(@tensor2d, :f32)
33+
|> Nx.slice([0, 2], [4, 6])
34+
|> Nx.backend_transfer()
35+
36+
assert bin == ort
37+
end
38+
39+
test "1d slice u8" do
40+
bin = tensor_binary(@tensor1d, :u8) |> Nx.slice([0], [4])
41+
42+
ort = tensor_ortex(@tensor1d, :u8) |> Nx.slice([0], [4]) |> Nx.backend_transfer()
43+
44+
assert bin == ort
45+
end
46+
47+
test "2d slice u8" do
48+
bin = tensor_binary(@tensor2d, :u8) |> Nx.slice([0, 2], [4, 6])
49+
50+
ort =
51+
tensor_ortex(@tensor2d, :u8)
52+
|> Nx.slice([0, 2], [4, 6])
53+
|> Nx.backend_transfer()
54+
55+
assert bin == ort
56+
end
57+
58+
test "1d slice f32 strided" do
59+
bin = tensor_binary(@tensor1d, :f32) |> Nx.slice([0], [4], strides: [2])
60+
61+
ort =
62+
tensor_ortex(@tensor1d, :f32) |> Nx.slice([0], [4], strides: [2]) |> Nx.backend_transfer()
63+
64+
assert bin == ort
65+
end
66+
67+
test "2d slice f32 strided" do
68+
bin = tensor_binary(@tensor2d, :f32) |> Nx.slice([0, 2], [4, 6], strides: [2, 1])
69+
70+
ort =
71+
tensor_ortex(@tensor2d, :f32)
72+
|> Nx.slice([0, 2], [4, 6], strides: [2, 1])
73+
|> Nx.backend_transfer()
74+
75+
assert bin == ort
76+
end
77+
78+
test "1d slice u8 strided" do
79+
bin = tensor_binary(@tensor1d, :u8) |> Nx.slice([0], [4], strides: [2])
80+
81+
ort =
82+
tensor_ortex(@tensor1d, :u8) |> Nx.slice([0], [4], strides: [2]) |> Nx.backend_transfer()
83+
84+
assert bin == ort
85+
end
86+
87+
test "2d slice u8 strided" do
88+
bin = tensor_binary(@tensor2d, :u8) |> Nx.slice([0, 2], [4, 6], strides: [2, 1])
89+
90+
ort =
91+
tensor_ortex(@tensor2d, :u8)
92+
|> Nx.slice([0, 2], [4, 6], strides: [2, 1])
93+
|> Nx.backend_transfer()
94+
95+
assert bin == ort
96+
end
97+
end

0 commit comments

Comments
 (0)