Skip to content

Commit 6bbdb33

Browse files
author
Greg Szumel
committed
swapping for macro
1 parent 234b3f9 commit 6bbdb33

File tree

1 file changed

+47
-241
lines changed

1 file changed

+47
-241
lines changed

native/ortex/src/tensor.rs

Lines changed: 47 additions & 241 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3-
use ndarray::{ArrayBase, ArrayView, Data, IxDyn};
3+
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, ViewRepr, IxDynImpl};
44
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
55
use ort::OrtError;
66
use rustler::resource::ResourceArc;
@@ -28,6 +28,7 @@ pub enum OrtexTensor {
2828
}
2929

3030
impl From<&OrtexTensor> for InputTensor {
31+
3132
fn from(tensor: &OrtexTensor) -> Self {
3233
match tensor {
3334
OrtexTensor::s8(y) => InputTensor::from_array(y.clone().into()),
@@ -282,252 +283,57 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
282283
// Currently only supports concatenating tenors of the same type.
283284
//
284285
// This is a similar structure to the above match clauses, except each function
285-
// in map is more complex and needs to be written out explicitly, see below.
286-
//
287-
// Each fn concatenate_{type} verifies to the compiler that the vec<OrtexTensor>
288-
// all have the same type, and then we can concat easily from there
289-
//
290-
// TODO: make the fn concatenate_{type} a macro?
286+
// in map is more complex and needs to be written out explicitly. To reduce
287+
// repetition, the concatenate! macro expands that code and makes the necessary
288+
// minor tweaks
289+
290+
macro_rules! concatenate {
291+
// `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant
292+
($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) =>{
293+
{
294+
type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>;
295+
fn filter(tensor: &OrtexTensor) -> Option<ArrayType> {
296+
match tensor {
297+
OrtexTensor::$ort_tensor_kind(x) => Some(x.view()),
298+
_ => None,
299+
}
300+
}
301+
// hack way to type coalesce. Filters out any ndarray's that don't
302+
// have the desired type
303+
let tensors: Vec<ArrayType> =
304+
$tensors.iter().filter_map(|tensor| { filter(tensor) }).collect();
305+
306+
let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap();
307+
// data is not contiguous after the concatenation above. To decode
308+
// properly, need to create a new contiguous vector
309+
let tensors = Array::from_shape_vec(
310+
tensors.raw_dim(),
311+
tensors.iter().cloned().collect())
312+
.unwrap();
313+
OrtexTensor::$ort_tensor_kind(tensors)
314+
}
315+
}
316+
}
317+
291318
pub fn concatenate(
292319
tensors: Vec<ResourceArc<OrtexTensor>>,
293320
dtype: (&str, usize),
294321
axis: usize,
295322
) -> OrtexTensor {
323+
296324
match dtype {
297-
("s", 8) => concatenate_s8(tensors, axis),
298-
("s", 16) => concatenate_s16(tensors, axis),
299-
("s", 32) => concatenate_s32(tensors, axis),
300-
("s", 64) => concatenate_s64(tensors, axis),
301-
("u", 8) => concatenate_u8(tensors, axis),
302-
("u", 16) => concatenate_u16(tensors, axis),
303-
("u", 32) => concatenate_u32(tensors, axis),
304-
("u", 64) => concatenate_u64(tensors, axis),
305-
("f", 16) => concatenate_f16(tensors, axis),
306-
("bf", 16) => concatenate_bf16(tensors, axis),
307-
("f", 32) => concatenate_f32(tensors, axis),
308-
("f", 64) => concatenate_f64(tensors, axis),
325+
("s", 8) => concatenate!(tensors, axis, i8, s8),
326+
("s", 16) => concatenate!(tensors, axis, i16, s16),
327+
("s", 32) => concatenate!(tensors, axis, i32, s32),
328+
("s", 64) => concatenate!(tensors, axis, i64, s64),
329+
("u", 8) => concatenate!(tensors, axis, u8, u8),
330+
("u", 16) => concatenate!(tensors, axis, u16, u16),
331+
("u", 32) => concatenate!(tensors, axis, u32, u32),
332+
("u", 64) => concatenate!(tensors, axis, u64, u64),
333+
("f", 16) => concatenate!(tensors, axis, half::f16, f16),
334+
("bf", 16) => concatenate!(tensors, axis, half::bf16, bf16),
335+
("f", 32) => concatenate!(tensors, axis, f32, f32),
336+
("f", 64) => concatenate!(tensors, axis, f64, f64),
309337
_ => unimplemented!(),
310338
}
311339
}
312-
313-
// each of the below concatenate_{x} functions are identical except for the
314-
// underlying data-type / OrtexTensor enum
315-
fn concatenate_s8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
316-
// very hacky way to type coalesce, filter_map using an option
317-
fn filter_s8(
318-
of: &OrtexTensor,
319-
) -> Option<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> {
320-
match of {
321-
OrtexTensor::s8(x) => Some(x.view()),
322-
_ => None,
323-
}
324-
}
325-
326-
// now all tensors have the same type after filter_map()-ing
327-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> =
328-
tensors.iter().filter_map(|val| filter_s8(val)).collect();
329-
330-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
331-
332-
// because concatenating creates a non-standard data format, we copy the
333-
// data into a standard format shape. Otherwise, when converting to a
334-
// binary, the tensor's data is not ordered properly
335-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
336-
OrtexTensor::s8(x)
337-
}
338-
339-
fn concatenate_s16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
340-
fn filter_s16(
341-
of: &OrtexTensor,
342-
) -> Option<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> {
343-
match of {
344-
OrtexTensor::s16(x) => Some(x.view()),
345-
_ => None,
346-
}
347-
}
348-
349-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> =
350-
tensors.iter().filter_map(|val| filter_s16(val)).collect();
351-
352-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
353-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
354-
OrtexTensor::s16(x)
355-
}
356-
357-
fn concatenate_s32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
358-
fn filter_s32(
359-
of: &OrtexTensor,
360-
) -> Option<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> {
361-
match of {
362-
OrtexTensor::s32(x) => Some(x.view()),
363-
_ => None,
364-
}
365-
}
366-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> =
367-
tensors.iter().filter_map(|val| filter_s32(val)).collect();
368-
369-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
370-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
371-
OrtexTensor::s32(x)
372-
}
373-
374-
fn concatenate_s64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
375-
fn filter_s64(
376-
of: &OrtexTensor,
377-
) -> Option<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> {
378-
match of {
379-
OrtexTensor::s64(x) => Some(x.view()),
380-
_ => None,
381-
}
382-
}
383-
384-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> =
385-
tensors.iter().filter_map(|val| filter_s64(val)).collect();
386-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
387-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
388-
OrtexTensor::s64(x)
389-
}
390-
391-
fn concatenate_u8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
392-
fn filter_u8(
393-
of: &OrtexTensor,
394-
) -> Option<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> {
395-
match of {
396-
OrtexTensor::u8(x) => Some(x.view()),
397-
_ => None,
398-
}
399-
}
400-
401-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> =
402-
tensors.iter().filter_map(|val| filter_u8(val)).collect();
403-
404-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
405-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
406-
OrtexTensor::u8(x)
407-
}
408-
409-
fn concatenate_u16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
410-
fn filter_u16(
411-
of: &OrtexTensor,
412-
) -> Option<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> {
413-
match of {
414-
OrtexTensor::u16(x) => Some(x.view()),
415-
_ => None,
416-
}
417-
}
418-
419-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> =
420-
tensors.iter().filter_map(|val| filter_u16(val)).collect();
421-
422-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
423-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
424-
OrtexTensor::u16(x)
425-
}
426-
427-
fn concatenate_u32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
428-
fn filter_u32(
429-
of: &OrtexTensor,
430-
) -> Option<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> {
431-
match of {
432-
OrtexTensor::u32(x) => Some(x.view()),
433-
_ => None,
434-
}
435-
}
436-
437-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> =
438-
tensors.iter().filter_map(|val| filter_u32(val)).collect();
439-
440-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
441-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
442-
OrtexTensor::u32(x)
443-
}
444-
445-
fn concatenate_u64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
446-
fn filter_u64(
447-
of: &OrtexTensor,
448-
) -> Option<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> {
449-
match of {
450-
OrtexTensor::u64(x) => Some(x.view()),
451-
_ => None,
452-
}
453-
}
454-
455-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> =
456-
tensors.iter().filter_map(|val| filter_u64(val)).collect();
457-
458-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
459-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
460-
OrtexTensor::u64(x)
461-
}
462-
463-
fn concatenate_f16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
464-
fn filter_f16(
465-
of: &OrtexTensor,
466-
) -> Option<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> {
467-
match of {
468-
OrtexTensor::f16(x) => Some(x.view()),
469-
_ => None,
470-
}
471-
}
472-
473-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> =
474-
tensors.iter().filter_map(|val| filter_f16(val)).collect();
475-
476-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
477-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
478-
OrtexTensor::f16(x)
479-
}
480-
481-
fn concatenate_bf16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
482-
fn filter_bf16(
483-
of: &OrtexTensor,
484-
) -> Option<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> {
485-
match of {
486-
OrtexTensor::bf16(x) => Some(x.view()),
487-
_ => None,
488-
}
489-
}
490-
491-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> =
492-
tensors.iter().filter_map(|val| filter_bf16(val)).collect();
493-
494-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
495-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
496-
OrtexTensor::bf16(x)
497-
}
498-
499-
fn concatenate_f32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
500-
fn filter_f32(
501-
of: &OrtexTensor,
502-
) -> Option<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> {
503-
match of {
504-
OrtexTensor::f32(x) => Some(x.view()),
505-
_ => None,
506-
}
507-
}
508-
509-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> =
510-
tensors.iter().filter_map(|val| filter_f32(val)).collect();
511-
512-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
513-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
514-
OrtexTensor::f32(x)
515-
}
516-
517-
fn concatenate_f64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor {
518-
fn filter_f64(
519-
of: &OrtexTensor,
520-
) -> Option<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> {
521-
match of {
522-
OrtexTensor::f64(x) => Some(x.view()),
523-
_ => None,
524-
}
525-
}
526-
527-
let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> =
528-
tensors.iter().filter_map(|val| filter_f64(val)).collect();
529-
530-
let x = ndarray::concatenate(Axis(axis), &tensors).unwrap();
531-
let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap();
532-
OrtexTensor::f64(x)
533-
}

0 commit comments

Comments
 (0)