Skip to content

Commit 234b3f9

Browse files
author
Greg Szumel
committed
adding concat functionality
1 parent 6ad9565 commit 234b3f9

File tree

5 files changed

+437
-1
lines changed

5 files changed

+437
-1
lines changed

lib/ortex/backend.ex

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@ defmodule Ortex.Backend do
107107
end
108108
end
109109

110+
@impl true
111+
def concatenate(out, tensors, axis) do
112+
if not Enum.all?(tensors, fn t -> t.type == out.type end) do
113+
raise "Ortex does not currently support concatenation of vectors with differing types."
114+
end
115+
116+
tensor_refs =
117+
Enum.map(tensors, fn t ->
118+
%T{data: %B{ref: ref}} = t
119+
ref
120+
end)
121+
122+
type = out.type
123+
124+
%{out | data: %B{ref: Ortex.Native.concatenate(tensor_refs, type, axis)}}
125+
end
126+
110127
if Application.compile_env(:ortex, :add_backend_on_inspect, true) do
111128
defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do
112129
Inspect.Algebra.concat([

lib/ortex/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ defmodule Ortex.Native do
2626
do: :erlang.nif_error(:nif_not_loaded)
2727

2828
def reshape(_tensor, _shape), do: :erlang.nif_error(:nif_not_loaded)
29+
30+
def concatenate(_tensors_refs, _type, _axis), do: :erlang.nif_error(:nif_not_loaded)
2931
end

native/ortex/src/lib.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ pub fn reshape<'a>(
9393
Ok(ResourceArc::new(tensor.reshape(shape)?))
9494
}
9595

96+
#[rustler::nif]
97+
pub fn concatenate<'a>(
98+
tensors: Vec<ResourceArc<OrtexTensor>>,
99+
dtype: Term,
100+
axis: i32,
101+
) -> NifResult<ResourceArc<OrtexTensor>> {
102+
let (dtype_t, dtype_bits): (Term, usize) = dtype.decode()?;
103+
let dtype_str = dtype_t.atom_to_string()?;
104+
let concatted = tensor::concatenate(tensors, (&dtype_str, dtype_bits), axis as usize);
105+
Ok(ResourceArc::new(concatted))
106+
}
107+
96108
rustler::init!(
97109
"Elixir.Ortex.Native",
98110
[
@@ -102,7 +114,8 @@ rustler::init!(
102114
to_binary,
103115
show_session,
104116
slice,
105-
reshape
117+
reshape,
118+
concatenate
106119
],
107120
load = |env: Env, _| {
108121
rustler::resource!(OrtexModel, env);

native/ortex/src/tensor.rs

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use ndarray::prelude::*;
33
use ndarray::{ArrayBase, ArrayView, Data, IxDyn};
44
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
55
use ort::OrtError;
6+
use rustler::resource::ResourceArc;
67
use rustler::Atom;
78

89
use crate::constants::ortex_atoms;
@@ -277,3 +278,256 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
277278
}
278279
}
279280
}
281+
282+
// Currently only supports concatenating tenors of the same type.
283+
//
284+
// This is a similar structure to the above match clauses, except each function
285+
// in map is more complex and needs to be written out explicitly, see below.
286+
//
287+
// Each fn concatenate_{type} verifies to the compiler that the vec<OrtexTensor>
288+
// all have the same type, and then we can concat easily from there
289+
//
290+
// TODO: make the fn concatenate_{type} a macro?
291+
pub fn concatenate(
292+
tensors: Vec<ResourceArc<OrtexTensor>>,
293+
dtype: (&str, usize),
294+
axis: usize,
295+
) -> OrtexTensor {
296+
match dtype {
297+
("s", 8) => concatenate_s8(tensors, axis),
298+
("s", 16) => concatenate_s16(tensors, axis),
299+
("s", 32) => concatenate_s32(tensors, axis),
300+
("s", 64) => concatenate_s64(tensors, axis),
301+
("u", 8) => concatenate_u8(tensors, axis),
302+
("u", 16) => concatenate_u16(tensors, axis),
303+
("u", 32) => concatenate_u32(tensors, axis),
304+
("u", 64) => concatenate_u64(tensors, axis),
305+
("f", 16) => concatenate_f16(tensors, axis),
306+
("bf", 16) => concatenate_bf16(tensors, axis),
307+
("f", 32) => concatenate_f32(tensors, axis),
308+
("f", 64) => concatenate_f64(tensors, axis),
309+
_ => unimplemented!(),
310+
}
311+
}
312+
313+
// each of the below concatenate_{x} functions are identical except for the
314+
// underlying data-type / OrtexTensor enum
315+
fn concatenate_s8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
316+
// very hacky way to type coalesce, filter_map using an option
317+
fn filter_s8(
318+
of: &OrtexTensor,
319+
) -> Option<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> {
320+
match of {
321+
OrtexTensor::s8(x) => Some(x.view()),
322+
_ => None,
323+
}
324+
}
325+
326+
// now all tensors have the same type after filter_map()-ing
327+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> =
328+
tensors.iter().filter_map(|val| filter_s8(val)).collect();
329+
330+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
331+
332+
// because concatenating creates a non-standard data format, we copy the
333+
// data into a standard format shape. Otherwise, when converting to a
334+
// binary, the tensor's data is not ordered properly
335+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
336+
OrtexTensor::s8(x)
337+
}
338+
339+
fn concatenate_s16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
340+
fn filter_s16(
341+
of: &OrtexTensor,
342+
) -> Option<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> {
343+
match of {
344+
OrtexTensor::s16(x) => Some(x.view()),
345+
_ => None,
346+
}
347+
}
348+
349+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> =
350+
tensors.iter().filter_map(|val| filter_s16(val)).collect();
351+
352+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
353+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
354+
OrtexTensor::s16(x)
355+
}
356+
357+
fn concatenate_s32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
358+
fn filter_s32(
359+
of: &OrtexTensor,
360+
) -> Option<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> {
361+
match of {
362+
OrtexTensor::s32(x) => Some(x.view()),
363+
_ => None,
364+
}
365+
}
366+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> =
367+
tensors.iter().filter_map(|val| filter_s32(val)).collect();
368+
369+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
370+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
371+
OrtexTensor::s32(x)
372+
}
373+
374+
fn concatenate_s64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
375+
fn filter_s64(
376+
of: &OrtexTensor,
377+
) -> Option<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> {
378+
match of {
379+
OrtexTensor::s64(x) => Some(x.view()),
380+
_ => None,
381+
}
382+
}
383+
384+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> =
385+
tensors.iter().filter_map(|val| filter_s64(val)).collect();
386+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
387+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
388+
OrtexTensor::s64(x)
389+
}
390+
391+
fn concatenate_u8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
392+
fn filter_u8(
393+
of: &OrtexTensor,
394+
) -> Option<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> {
395+
match of {
396+
OrtexTensor::u8(x) => Some(x.view()),
397+
_ => None,
398+
}
399+
}
400+
401+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> =
402+
tensors.iter().filter_map(|val| filter_u8(val)).collect();
403+
404+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
405+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
406+
OrtexTensor::u8(x)
407+
}
408+
409+
fn concatenate_u16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
410+
fn filter_u16(
411+
of: &OrtexTensor,
412+
) -> Option<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> {
413+
match of {
414+
OrtexTensor::u16(x) => Some(x.view()),
415+
_ => None,
416+
}
417+
}
418+
419+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> =
420+
tensors.iter().filter_map(|val| filter_u16(val)).collect();
421+
422+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
423+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
424+
OrtexTensor::u16(x)
425+
}
426+
427+
fn concatenate_u32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
428+
fn filter_u32(
429+
of: &OrtexTensor,
430+
) -> Option<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> {
431+
match of {
432+
OrtexTensor::u32(x) => Some(x.view()),
433+
_ => None,
434+
}
435+
}
436+
437+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> =
438+
tensors.iter().filter_map(|val| filter_u32(val)).collect();
439+
440+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
441+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
442+
OrtexTensor::u32(x)
443+
}
444+
445+
fn concatenate_u64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
446+
fn filter_u64(
447+
of: &OrtexTensor,
448+
) -> Option<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> {
449+
match of {
450+
OrtexTensor::u64(x) => Some(x.view()),
451+
_ => None,
452+
}
453+
}
454+
455+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> =
456+
tensors.iter().filter_map(|val| filter_u64(val)).collect();
457+
458+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
459+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
460+
OrtexTensor::u64(x)
461+
}
462+
463+
fn concatenate_f16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
464+
fn filter_f16(
465+
of: &OrtexTensor,
466+
) -> Option<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> {
467+
match of {
468+
OrtexTensor::f16(x) => Some(x.view()),
469+
_ => None,
470+
}
471+
}
472+
473+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> =
474+
tensors.iter().filter_map(|val| filter_f16(val)).collect();
475+
476+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
477+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
478+
OrtexTensor::f16(x)
479+
}
480+
481+
fn concatenate_bf16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
482+
fn filter_bf16(
483+
of: &OrtexTensor,
484+
) -> Option<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> {
485+
match of {
486+
OrtexTensor::bf16(x) => Some(x.view()),
487+
_ => None,
488+
}
489+
}
490+
491+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> =
492+
tensors.iter().filter_map(|val| filter_bf16(val)).collect();
493+
494+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
495+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
496+
OrtexTensor::bf16(x)
497+
}
498+
499+
fn concatenate_f32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
500+
fn filter_f32(
501+
of: &OrtexTensor,
502+
) -> Option<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> {
503+
match of {
504+
OrtexTensor::f32(x) => Some(x.view()),
505+
_ => None,
506+
}
507+
}
508+
509+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> =
510+
tensors.iter().filter_map(|val| filter_f32(val)).collect();
511+
512+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
513+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
514+
OrtexTensor::f32(x)
515+
}
516+
517+
fn concatenate_f64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
518+
fn filter_f64(
519+
of: &OrtexTensor,
520+
) -> Option<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> {
521+
match of {
522+
OrtexTensor::f64(x) => Some(x.view()),
523+
_ => None,
524+
}
525+
}
526+
527+
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> =
528+
tensors.iter().filter_map(|val| filter_f64(val)).collect();
529+
530+
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
531+
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
532+
OrtexTensor::f64(x)
533+
}

0 commit comments

Comments
 (0)