@@ -3,6 +3,7 @@ use ndarray::prelude::*;
33use ndarray:: { ArrayBase , ArrayView , Data , IxDyn } ;
44use ort:: tensor:: { DynOrtTensor , FromArray , InputTensor , TensorElementDataType } ;
55use ort:: OrtError ;
6+ use rustler:: resource:: ResourceArc ;
67use rustler:: Atom ;
78
89use crate :: constants:: ortex_atoms;
@@ -277,3 +278,256 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
277278 }
278279 }
279280}
281+
282+ // Currently only supports concatenating tenors of the same type.
283+ //
284+ // This is a similar structure to the above match clauses, except each function
285+ // in map is more complex and needs to be written out explicitly, see below.
286+ //
287+ // Each fn concatenate_{type} verifies to the compiler that the vec<OrtexTensor>
288+ // all have the same type, and then we can concat easily from there
289+ //
290+ // TODO: make the fn concatenate_{type} a macro?
291+ pub fn concatenate (
292+ tensors : Vec < ResourceArc < OrtexTensor > > ,
293+ dtype : ( & str , usize ) ,
294+ axis : usize ,
295+ ) -> OrtexTensor {
296+ match dtype {
297+ ( "s" , 8 ) => concatenate_s8 ( tensors, axis) ,
298+ ( "s" , 16 ) => concatenate_s16 ( tensors, axis) ,
299+ ( "s" , 32 ) => concatenate_s32 ( tensors, axis) ,
300+ ( "s" , 64 ) => concatenate_s64 ( tensors, axis) ,
301+ ( "u" , 8 ) => concatenate_u8 ( tensors, axis) ,
302+ ( "u" , 16 ) => concatenate_u16 ( tensors, axis) ,
303+ ( "u" , 32 ) => concatenate_u32 ( tensors, axis) ,
304+ ( "u" , 64 ) => concatenate_u64 ( tensors, axis) ,
305+ ( "f" , 16 ) => concatenate_f16 ( tensors, axis) ,
306+ ( "bf" , 16 ) => concatenate_bf16 ( tensors, axis) ,
307+ ( "f" , 32 ) => concatenate_f32 ( tensors, axis) ,
308+ ( "f" , 64 ) => concatenate_f64 ( tensors, axis) ,
309+ _ => unimplemented ! ( ) ,
310+ }
311+ }
312+
313+ // each of the below concatenate_{x} functions are identical except for the
314+ // underlying data-type / OrtexTensor enum
315+ fn concatenate_s8 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
316+ // very hacky way to type coalesce, filter_map using an option
317+ fn filter_s8 (
318+ of : & OrtexTensor ,
319+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & i8 > , Dim < ndarray:: IxDynImpl > > > {
320+ match of {
321+ OrtexTensor :: s8( x) => Some ( x. view ( ) ) ,
322+ _ => None ,
323+ }
324+ }
325+
326+ // now all tensors have the same type after filter_map()-ing
327+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & i8 > , Dim < ndarray:: IxDynImpl > > > =
328+ tensors. iter ( ) . filter_map ( |val| filter_s8 ( val) ) . collect ( ) ;
329+
330+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
331+
332+ // because concatenating creates a non-standard data format, we copy the
333+ // data into a standard format shape. Otherwise, when converting to a
334+ // binary, the tensor's data is not ordered properly
335+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
336+ OrtexTensor :: s8 ( x)
337+ }
338+
339+ fn concatenate_s16 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
340+ fn filter_s16 (
341+ of : & OrtexTensor ,
342+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & i16 > , Dim < ndarray:: IxDynImpl > > > {
343+ match of {
344+ OrtexTensor :: s16( x) => Some ( x. view ( ) ) ,
345+ _ => None ,
346+ }
347+ }
348+
349+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & i16 > , Dim < ndarray:: IxDynImpl > > > =
350+ tensors. iter ( ) . filter_map ( |val| filter_s16 ( val) ) . collect ( ) ;
351+
352+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
353+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
354+ OrtexTensor :: s16 ( x)
355+ }
356+
357+ fn concatenate_s32 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
358+ fn filter_s32 (
359+ of : & OrtexTensor ,
360+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & i32 > , Dim < ndarray:: IxDynImpl > > > {
361+ match of {
362+ OrtexTensor :: s32( x) => Some ( x. view ( ) ) ,
363+ _ => None ,
364+ }
365+ }
366+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & i32 > , Dim < ndarray:: IxDynImpl > > > =
367+ tensors. iter ( ) . filter_map ( |val| filter_s32 ( val) ) . collect ( ) ;
368+
369+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
370+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
371+ OrtexTensor :: s32 ( x)
372+ }
373+
374+ fn concatenate_s64 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
375+ fn filter_s64 (
376+ of : & OrtexTensor ,
377+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & i64 > , Dim < ndarray:: IxDynImpl > > > {
378+ match of {
379+ OrtexTensor :: s64( x) => Some ( x. view ( ) ) ,
380+ _ => None ,
381+ }
382+ }
383+
384+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & i64 > , Dim < ndarray:: IxDynImpl > > > =
385+ tensors. iter ( ) . filter_map ( |val| filter_s64 ( val) ) . collect ( ) ;
386+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
387+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
388+ OrtexTensor :: s64 ( x)
389+ }
390+
391+ fn concatenate_u8 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
392+ fn filter_u8 (
393+ of : & OrtexTensor ,
394+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & u8 > , Dim < ndarray:: IxDynImpl > > > {
395+ match of {
396+ OrtexTensor :: u8( x) => Some ( x. view ( ) ) ,
397+ _ => None ,
398+ }
399+ }
400+
401+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & u8 > , Dim < ndarray:: IxDynImpl > > > =
402+ tensors. iter ( ) . filter_map ( |val| filter_u8 ( val) ) . collect ( ) ;
403+
404+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
405+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
406+ OrtexTensor :: u8 ( x)
407+ }
408+
409+ fn concatenate_u16 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
410+ fn filter_u16 (
411+ of : & OrtexTensor ,
412+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & u16 > , Dim < ndarray:: IxDynImpl > > > {
413+ match of {
414+ OrtexTensor :: u16( x) => Some ( x. view ( ) ) ,
415+ _ => None ,
416+ }
417+ }
418+
419+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & u16 > , Dim < ndarray:: IxDynImpl > > > =
420+ tensors. iter ( ) . filter_map ( |val| filter_u16 ( val) ) . collect ( ) ;
421+
422+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
423+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
424+ OrtexTensor :: u16 ( x)
425+ }
426+
427+ fn concatenate_u32 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
428+ fn filter_u32 (
429+ of : & OrtexTensor ,
430+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & u32 > , Dim < ndarray:: IxDynImpl > > > {
431+ match of {
432+ OrtexTensor :: u32( x) => Some ( x. view ( ) ) ,
433+ _ => None ,
434+ }
435+ }
436+
437+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & u32 > , Dim < ndarray:: IxDynImpl > > > =
438+ tensors. iter ( ) . filter_map ( |val| filter_u32 ( val) ) . collect ( ) ;
439+
440+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
441+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
442+ OrtexTensor :: u32 ( x)
443+ }
444+
445+ fn concatenate_u64 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
446+ fn filter_u64 (
447+ of : & OrtexTensor ,
448+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & u64 > , Dim < ndarray:: IxDynImpl > > > {
449+ match of {
450+ OrtexTensor :: u64( x) => Some ( x. view ( ) ) ,
451+ _ => None ,
452+ }
453+ }
454+
455+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & u64 > , Dim < ndarray:: IxDynImpl > > > =
456+ tensors. iter ( ) . filter_map ( |val| filter_u64 ( val) ) . collect ( ) ;
457+
458+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
459+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
460+ OrtexTensor :: u64 ( x)
461+ }
462+
463+ fn concatenate_f16 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
464+ fn filter_f16 (
465+ of : & OrtexTensor ,
466+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & half:: f16 > , Dim < ndarray:: IxDynImpl > > > {
467+ match of {
468+ OrtexTensor :: f16( x) => Some ( x. view ( ) ) ,
469+ _ => None ,
470+ }
471+ }
472+
473+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & half:: f16 > , Dim < ndarray:: IxDynImpl > > > =
474+ tensors. iter ( ) . filter_map ( |val| filter_f16 ( val) ) . collect ( ) ;
475+
476+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
477+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
478+ OrtexTensor :: f16 ( x)
479+ }
480+
481+ fn concatenate_bf16 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
482+ fn filter_bf16 (
483+ of : & OrtexTensor ,
484+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & half:: bf16 > , Dim < ndarray:: IxDynImpl > > > {
485+ match of {
486+ OrtexTensor :: bf16( x) => Some ( x. view ( ) ) ,
487+ _ => None ,
488+ }
489+ }
490+
491+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & half:: bf16 > , Dim < ndarray:: IxDynImpl > > > =
492+ tensors. iter ( ) . filter_map ( |val| filter_bf16 ( val) ) . collect ( ) ;
493+
494+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
495+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
496+ OrtexTensor :: bf16 ( x)
497+ }
498+
499+ fn concatenate_f32 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
500+ fn filter_f32 (
501+ of : & OrtexTensor ,
502+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & f32 > , Dim < ndarray:: IxDynImpl > > > {
503+ match of {
504+ OrtexTensor :: f32( x) => Some ( x. view ( ) ) ,
505+ _ => None ,
506+ }
507+ }
508+
509+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & f32 > , Dim < ndarray:: IxDynImpl > > > =
510+ tensors. iter ( ) . filter_map ( |val| filter_f32 ( val) ) . collect ( ) ;
511+
512+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
513+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
514+ OrtexTensor :: f32 ( x)
515+ }
516+
517+ fn concatenate_f64 ( tensors : Vec < ResourceArc < OrtexTensor > > , axis : usize ) -> OrtexTensor {
518+ fn filter_f64 (
519+ of : & OrtexTensor ,
520+ ) -> Option < ArrayBase < ndarray:: ViewRepr < & f64 > , Dim < ndarray:: IxDynImpl > > > {
521+ match of {
522+ OrtexTensor :: f64( x) => Some ( x. view ( ) ) ,
523+ _ => None ,
524+ }
525+ }
526+
527+ let tensors: Vec < ArrayBase < ndarray:: ViewRepr < & f64 > , Dim < ndarray:: IxDynImpl > > > =
528+ tensors. iter ( ) . filter_map ( |val| filter_f64 ( val) ) . collect ( ) ;
529+
530+ let x = ndarray:: concatenate ( Axis ( axis) , & tensors) . unwrap ( ) ;
531+ let x = Array :: from_shape_vec ( x. raw_dim ( ) , x. iter ( ) . cloned ( ) . collect ( ) ) . unwrap ( ) ;
532+ OrtexTensor :: f64 ( x)
533+ }
0 commit comments