Skip to content

Commit 7ad0b3d

Browse files
authored
Merge pull request #22 from gregszumel/concat_changes
adding concat functionality
2 parents 6ad9565 + 4fa5755 commit 7ad0b3d

File tree

5 files changed

+222
-2
lines changed

5 files changed

+222
-2
lines changed

lib/ortex/backend.ex

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@ defmodule Ortex.Backend do
107107
end
108108
end
109109

110+
@impl true
111+
def concatenate(out, tensors, axis) do
112+
if not Enum.all?(tensors, fn t -> t.type == out.type end) do
113+
raise "Ortex does not currently support concatenation of vectors with differing types."
114+
end
115+
116+
tensor_refs =
117+
Enum.map(tensors, fn t ->
118+
%T{data: %B{ref: ref}} = t
119+
ref
120+
end)
121+
122+
type = out.type
123+
124+
%{out | data: %B{ref: Ortex.Native.concatenate(tensor_refs, type, axis)}}
125+
end
126+
110127
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do
111128
defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do
112129
Inspect.Algebra.concat([

lib/ortex/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ defmodule Ortex.Native do
2626
do: :erlang.nif_error(:nif_not_loaded)
2727

2828
def reshape(_tensor, _shape), do: :erlang.nif_error(:nif_not_loaded)
29+
30+
def concatenate(_tensors_refs, _type, _axis), do: :erlang.nif_error(:nif_not_loaded)
2931
end

native/ortex/src/lib.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ pub fn reshape<'a>(
9393
Ok(ResourceArc::new(tensor.reshape(shape)?))
9494
}
9595

96+
#[rustler::nif]
97+
pub fn concatenate<'a>(
98+
tensors: Vec<ResourceArc<OrtexTensor>>,
99+
dtype: Term,
100+
axis: i32,
101+
) -> NifResult<ResourceArc<OrtexTensor>> {
102+
let (dtype_t, dtype_bits): (Term, usize) = dtype.decode()?;
103+
let dtype_str = dtype_t.atom_to_string()?;
104+
let concatted = tensor::concatenate(tensors, (&dtype_str, dtype_bits), axis as usize);
105+
Ok(ResourceArc::new(concatted))
106+
}
107+
96108
rustler::init!(
97109
"Elixir.Ortex.Native",
98110
[
@@ -102,7 +114,8 @@ rustler::init!(
102114
to_binary,
103115
show_session,
104116
slice,
105-
reshape
117+
reshape,
118+
concatenate
106119
],
107120
load = |env: Env, _| {
108121
rustler::resource!(OrtexModel, env);

native/ortex/src/tensor.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3-
use ndarray::{ArrayBase, ArrayView, Data, IxDyn};
3+
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr};
44
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
55
use ort::OrtError;
6+
use rustler::resource::ResourceArc;
67
use rustler::Atom;
78

89
use crate::constants::ortex_atoms;
@@ -277,3 +278,58 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
277278
}
278279
}
279280
}
281+
282+
// Currently only supports concatenating tenors of the same type.
283+
//
284+
// This is a similar structure to the above match clauses, except each function
285+
// in map is more complex and needs to be written out explicitly. To reduce
286+
// repetition, the concatenate! macro expands that code and makes the necessary
287+
// minor tweaks
288+
289+
macro_rules! concatenate {
290+
// `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant
291+
($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{
292+
type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>;
293+
fn filter(tensor: &OrtexTensor) -> Option<ArrayType> {
294+
match tensor {
295+
OrtexTensor::$ort_tensor_kind(x) => Some(x.view()),
296+
_ => None,
297+
}
298+
}
299+
// hack way to type coalesce. Filters out any ndarray's that don't
300+
// have the desired type
301+
let tensors: Vec<ArrayType> = $tensors
302+
.iter()
303+
.filter_map(|tensor| filter(tensor))
304+
.collect();
305+
306+
let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap();
307+
// data is not contiguous after the concatenation above. To decode
308+
// properly, need to create a new contiguous vector
309+
let tensors =
310+
Array::from_shape_vec(tensors.raw_dim(), tensors.iter().cloned().collect()).unwrap();
311+
OrtexTensor::$ort_tensor_kind(tensors)
312+
}};
313+
}
314+
315+
pub fn concatenate(
316+
tensors: Vec<ResourceArc<OrtexTensor>>,
317+
dtype: (&str, usize),
318+
axis: usize,
319+
) -> OrtexTensor {
320+
match dtype {
321+
("s", 8) => concatenate!(tensors, axis, i8, s8),
322+
("s", 16) => concatenate!(tensors, axis, i16, s16),
323+
("s", 32) => concatenate!(tensors, axis, i32, s32),
324+
("s", 64) => concatenate!(tensors, axis, i64, s64),
325+
("u", 8) => concatenate!(tensors, axis, u8, u8),
326+
("u", 16) => concatenate!(tensors, axis, u16, u16),
327+
("u", 32) => concatenate!(tensors, axis, u32, u32),
328+
("u", 64) => concatenate!(tensors, axis, u64, u64),
329+
("f", 16) => concatenate!(tensors, axis, half::f16, f16),
330+
("bf", 16) => concatenate!(tensors, axis, half::bf16, bf16),
331+
("f", 32) => concatenate!(tensors, axis, f32, f32),
332+
("f", 64) => concatenate!(tensors, axis, f64, f64),
333+
_ => unimplemented!(),
334+
}
335+
}

test/shape/concat_test.exs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
defmodule Ortex.TestConcat do
2+
use ExUnit.Case
3+
4+
# Testing each type, since there's a bunch of boilerplate that we want to
5+
# check for errors on the Rust side
6+
%{
7+
"s8" => {:s, 8},
8+
"s16" => {:s, 16},
9+
"s32" => {:s, 16},
10+
"s64" => {:s, 16},
11+
"u8" => {:s, 16},
12+
"u16" => {:s, 16},
13+
"u32" => {:s, 16},
14+
"u64" => {:s, 16},
15+
"f16" => {:s, 16},
16+
"bf16" => {:s, 16},
17+
"f32" => {:s, 16},
18+
"f64" => {:s, 16}
19+
}
20+
|> Enum.each(fn {type_str, type_tuple} ->
21+
test "Concat 1d tensors #{type_str}" do
22+
t1 =
23+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
24+
25+
t2 =
26+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
27+
28+
concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer()
29+
expected = Nx.tensor([1, 2, 3, 4, 1, 2, 3, 4], type: unquote(type_tuple))
30+
assert concatted == expected
31+
end
32+
33+
test "Concat 3d tensors #{type_str}" do
34+
o1 = Nx.iota({2, 3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
35+
o2 = Nx.iota({1, 3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
36+
concatted = Nx.concatenate([o1, o2]) |> Nx.backend_transfer()
37+
expected = Nx.concatenate([o1 |> Nx.backend_transfer(), o2 |> Nx.backend_transfer()])
38+
assert concatted == expected
39+
end
40+
41+
test "Concat 3 #{type_str} vectors" do
42+
t1 =
43+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
44+
45+
t2 =
46+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
47+
48+
t3 =
49+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
50+
51+
concatted = Nx.concatenate([t1, t2, t3]) |> Nx.backend_transfer()
52+
expected = Nx.tensor([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], type: unquote(type_tuple))
53+
assert concatted == expected
54+
end
55+
56+
test "Concat axis #{type_str} 1" do
57+
o1 = Nx.iota({3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
58+
o2 = Nx.iota({3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
59+
60+
concatted = Nx.concatenate([o1, o2], axis: 1) |> Nx.backend_transfer()
61+
62+
n1 = Nx.iota({3, 5}, type: unquote(type_tuple))
63+
n2 = Nx.iota({3, 5}, type: unquote(type_tuple))
64+
65+
expected = Nx.concatenate([n1, n2], axis: 1)
66+
assert concatted == expected
67+
end
68+
69+
test "Concat axis 1 of three 3-dimensional #{type_str} vector" do
70+
t1 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
71+
t2 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
72+
t3 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
73+
74+
concatted = Nx.concatenate([t1, t2, t3], axis: 1) |> Nx.backend_transfer()
75+
76+
expected =
77+
Nx.concatenate(
78+
[
79+
Nx.iota({3, 5, 7}, type: unquote(type_tuple)),
80+
Nx.iota({3, 5, 7}, type: unquote(type_tuple)),
81+
Nx.iota({3, 5, 7}, type: unquote(type_tuple))
82+
],
83+
axis: 1
84+
)
85+
86+
assert concatted == expected
87+
end
88+
89+
test "Concat axis 2 of three 3-dimensional #{type_str} vector" do
90+
t1 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
91+
t2 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
92+
t3 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
93+
94+
concatted = Nx.concatenate([t1, t2, t3], axis: 2) |> Nx.backend_transfer()
95+
96+
expected =
97+
Nx.concatenate(
98+
[
99+
Nx.iota({3, 5, 7}, type: unquote(type_tuple)),
100+
Nx.iota({3, 5, 7}, type: unquote(type_tuple)),
101+
Nx.iota({3, 5, 7}, type: unquote(type_tuple))
102+
],
103+
axis: 2
104+
)
105+
106+
assert concatted == expected
107+
end
108+
109+
test "Concat doesn't alter component #{type_str} vectors" do
110+
t1 =
111+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
112+
113+
t2 =
114+
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend)
115+
116+
concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer()
117+
second_concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer()
118+
119+
assert concatted == second_concatted
120+
end
121+
end)
122+
123+
test "Concat fails to concat vectors of differing types" do
124+
assert_raise RuntimeError,
125+
"Ortex does not currently support concatenation of vectors with differing types.",
126+
fn ->
127+
t1 = Nx.tensor([1, 2, 3], type: {:s, 16}) |> Nx.backend_transfer(Ortex.Backend)
128+
t2 = Nx.tensor([1, 2, 3], type: {:s, 32}) |> Nx.backend_transfer(Ortex.Backend)
129+
_err = Nx.concatenate([t1, t2])
130+
end
131+
end
132+
end

0 commit comments

Comments
 (0)