|
1 | 1 | //! Conversions for packing/unpacking `OrtexTensor`s into different types |
2 | 2 | use ndarray::prelude::*; |
3 | | -use ndarray::{ArrayBase, ArrayView, Data, IxDyn}; |
| 3 | +use ndarray::{ArrayBase, ArrayView, Data, IxDyn, ViewRepr, IxDynImpl}; |
4 | 4 | use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType}; |
5 | 5 | use ort::OrtError; |
6 | 6 | use rustler::resource::ResourceArc; |
@@ -28,6 +28,7 @@ pub enum OrtexTensor { |
28 | 28 | } |
29 | 29 |
|
30 | 30 | impl From<&OrtexTensor> for InputTensor { |
| 31 | + |
31 | 32 | fn from(tensor: &OrtexTensor) -> Self { |
32 | 33 | match tensor { |
33 | 34 | OrtexTensor::s8(y) => InputTensor::from_array(y.clone().into()), |
@@ -282,252 +283,57 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor { |
282 | 283 | // Currently only supports concatenating tenors of the same type. |
283 | 284 | // |
284 | 285 | // This is a similar structure to the above match clauses, except each function |
285 | | -// in map is more complex and needs to be written out explicitly, see below. |
286 | | -// |
287 | | -// Each fn concatenate_{type} verifies to the compiler that the vec<OrtexTensor> |
288 | | -// all have the same type, and then we can concat easily from there |
289 | | -// |
290 | | -// TODO: make the fn concatenate_{type} a macro? |
| 286 | +// in map is more complex and needs to be written out explicitly. To reduce |
| 287 | +// repetition, the concatenate! macro expands that code and makes the necessary |
| 288 | +// minor tweaks |
| 289 | + |
| 290 | +macro_rules! concatenate { |
| 291 | + // `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant |
| 292 | + ($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) =>{ |
| 293 | + { |
| 294 | + type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>; |
| 295 | + fn filter(tensor: &OrtexTensor) -> Option<ArrayType> { |
| 296 | + match tensor { |
| 297 | + OrtexTensor::$ort_tensor_kind(x) => Some(x.view()), |
| 298 | + _ => None, |
| 299 | + } |
| 300 | + } |
| 301 | + // hack way to type coalesce. Filters out any ndarray's that don't |
| 302 | + // have the desired type |
| 303 | + let tensors: Vec<ArrayType> = |
| 304 | + $tensors.iter().filter_map(|tensor| { filter(tensor) }).collect(); |
| 305 | + |
| 306 | + let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap(); |
| 307 | + // data is not contiguous after the concatenation above. To decode |
| 308 | + // properly, need to create a new contiguous vector |
| 309 | + let tensors = Array::from_shape_vec( |
| 310 | + tensors.raw_dim(), |
| 311 | + tensors.iter().cloned().collect()) |
| 312 | + .unwrap(); |
| 313 | + OrtexTensor::$ort_tensor_kind(tensors) |
| 314 | + } |
| 315 | + } |
| 316 | +} |
| 317 | + |
291 | 318 | pub fn concatenate( |
292 | 319 | tensors: Vec<ResourceArc<OrtexTensor>>, |
293 | 320 | dtype: (&str, usize), |
294 | 321 | axis: usize, |
295 | 322 | ) -> OrtexTensor { |
| 323 | + |
296 | 324 | match dtype { |
297 | | - ("s", 8) => concatenate_s8(tensors, axis), |
298 | | - ("s", 16) => concatenate_s16(tensors, axis), |
299 | | - ("s", 32) => concatenate_s32(tensors, axis), |
300 | | - ("s", 64) => concatenate_s64(tensors, axis), |
301 | | - ("u", 8) => concatenate_u8(tensors, axis), |
302 | | - ("u", 16) => concatenate_u16(tensors, axis), |
303 | | - ("u", 32) => concatenate_u32(tensors, axis), |
304 | | - ("u", 64) => concatenate_u64(tensors, axis), |
305 | | - ("f", 16) => concatenate_f16(tensors, axis), |
306 | | - ("bf", 16) => concatenate_bf16(tensors, axis), |
307 | | - ("f", 32) => concatenate_f32(tensors, axis), |
308 | | - ("f", 64) => concatenate_f64(tensors, axis), |
| 325 | + ("s", 8) => concatenate!(tensors, axis, i8, s8), |
| 326 | + ("s", 16) => concatenate!(tensors, axis, i16, s16), |
| 327 | + ("s", 32) => concatenate!(tensors, axis, i32, s32), |
| 328 | + ("s", 64) => concatenate!(tensors, axis, i64, s64), |
| 329 | + ("u", 8) => concatenate!(tensors, axis, u8, u8), |
| 330 | + ("u", 16) => concatenate!(tensors, axis, u16, u16), |
| 331 | + ("u", 32) => concatenate!(tensors, axis, u32, u32), |
| 332 | + ("u", 64) => concatenate!(tensors, axis, u64, u64), |
| 333 | + ("f", 16) => concatenate!(tensors, axis, half::f16, f16), |
| 334 | + ("bf", 16) => concatenate!(tensors, axis, half::bf16, bf16), |
| 335 | + ("f", 32) => concatenate!(tensors, axis, f32, f32), |
| 336 | + ("f", 64) => concatenate!(tensors, axis, f64, f64), |
309 | 337 | _ => unimplemented!(), |
310 | 338 | } |
311 | 339 | } |
312 | | - |
313 | | -// each of the below concatenate_{x} functions are identical except for the |
314 | | -// underlying data-type / OrtexTensor enum |
315 | | -fn concatenate_s8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
316 | | - // very hacky way to type coalesce, filter_map using an option |
317 | | - fn filter_s8( |
318 | | - of: &OrtexTensor, |
319 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> { |
320 | | - match of { |
321 | | - OrtexTensor::s8(x) => Some(x.view()), |
322 | | - _ => None, |
323 | | - } |
324 | | - } |
325 | | - |
326 | | - // now all tensors have the same type after filter_map()-ing |
327 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i8>, Dim<ndarray::IxDynImpl>>> = |
328 | | - tensors.iter().filter_map(|val| filter_s8(val)).collect(); |
329 | | - |
330 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
331 | | - |
332 | | - // because concatenating creates a non-standard data format, we copy the |
333 | | - // data into a standard format shape. Otherwise, when converting to a |
334 | | - // binary, the tensor's data is not ordered properly |
335 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
336 | | - OrtexTensor::s8(x) |
337 | | -} |
338 | | - |
339 | | -fn concatenate_s16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
340 | | - fn filter_s16( |
341 | | - of: &OrtexTensor, |
342 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> { |
343 | | - match of { |
344 | | - OrtexTensor::s16(x) => Some(x.view()), |
345 | | - _ => None, |
346 | | - } |
347 | | - } |
348 | | - |
349 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i16>, Dim<ndarray::IxDynImpl>>> = |
350 | | - tensors.iter().filter_map(|val| filter_s16(val)).collect(); |
351 | | - |
352 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
353 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
354 | | - OrtexTensor::s16(x) |
355 | | -} |
356 | | - |
357 | | -fn concatenate_s32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
358 | | - fn filter_s32( |
359 | | - of: &OrtexTensor, |
360 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> { |
361 | | - match of { |
362 | | - OrtexTensor::s32(x) => Some(x.view()), |
363 | | - _ => None, |
364 | | - } |
365 | | - } |
366 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i32>, Dim<ndarray::IxDynImpl>>> = |
367 | | - tensors.iter().filter_map(|val| filter_s32(val)).collect(); |
368 | | - |
369 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
370 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
371 | | - OrtexTensor::s32(x) |
372 | | -} |
373 | | - |
374 | | -fn concatenate_s64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
375 | | - fn filter_s64( |
376 | | - of: &OrtexTensor, |
377 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> { |
378 | | - match of { |
379 | | - OrtexTensor::s64(x) => Some(x.view()), |
380 | | - _ => None, |
381 | | - } |
382 | | - } |
383 | | - |
384 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&i64>, Dim<ndarray::IxDynImpl>>> = |
385 | | - tensors.iter().filter_map(|val| filter_s64(val)).collect(); |
386 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
387 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
388 | | - OrtexTensor::s64(x) |
389 | | -} |
390 | | - |
391 | | -fn concatenate_u8(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
392 | | - fn filter_u8( |
393 | | - of: &OrtexTensor, |
394 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> { |
395 | | - match of { |
396 | | - OrtexTensor::u8(x) => Some(x.view()), |
397 | | - _ => None, |
398 | | - } |
399 | | - } |
400 | | - |
401 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u8>, Dim<ndarray::IxDynImpl>>> = |
402 | | - tensors.iter().filter_map(|val| filter_u8(val)).collect(); |
403 | | - |
404 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
405 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
406 | | - OrtexTensor::u8(x) |
407 | | -} |
408 | | - |
409 | | -fn concatenate_u16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
410 | | - fn filter_u16( |
411 | | - of: &OrtexTensor, |
412 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> { |
413 | | - match of { |
414 | | - OrtexTensor::u16(x) => Some(x.view()), |
415 | | - _ => None, |
416 | | - } |
417 | | - } |
418 | | - |
419 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u16>, Dim<ndarray::IxDynImpl>>> = |
420 | | - tensors.iter().filter_map(|val| filter_u16(val)).collect(); |
421 | | - |
422 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
423 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
424 | | - OrtexTensor::u16(x) |
425 | | -} |
426 | | - |
427 | | -fn concatenate_u32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
428 | | - fn filter_u32( |
429 | | - of: &OrtexTensor, |
430 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> { |
431 | | - match of { |
432 | | - OrtexTensor::u32(x) => Some(x.view()), |
433 | | - _ => None, |
434 | | - } |
435 | | - } |
436 | | - |
437 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u32>, Dim<ndarray::IxDynImpl>>> = |
438 | | - tensors.iter().filter_map(|val| filter_u32(val)).collect(); |
439 | | - |
440 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
441 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
442 | | - OrtexTensor::u32(x) |
443 | | -} |
444 | | - |
445 | | -fn concatenate_u64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
446 | | - fn filter_u64( |
447 | | - of: &OrtexTensor, |
448 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> { |
449 | | - match of { |
450 | | - OrtexTensor::u64(x) => Some(x.view()), |
451 | | - _ => None, |
452 | | - } |
453 | | - } |
454 | | - |
455 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&u64>, Dim<ndarray::IxDynImpl>>> = |
456 | | - tensors.iter().filter_map(|val| filter_u64(val)).collect(); |
457 | | - |
458 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
459 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
460 | | - OrtexTensor::u64(x) |
461 | | -} |
462 | | - |
463 | | -fn concatenate_f16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
464 | | - fn filter_f16( |
465 | | - of: &OrtexTensor, |
466 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> { |
467 | | - match of { |
468 | | - OrtexTensor::f16(x) => Some(x.view()), |
469 | | - _ => None, |
470 | | - } |
471 | | - } |
472 | | - |
473 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::f16>, Dim<ndarray::IxDynImpl>>> = |
474 | | - tensors.iter().filter_map(|val| filter_f16(val)).collect(); |
475 | | - |
476 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
477 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
478 | | - OrtexTensor::f16(x) |
479 | | -} |
480 | | - |
481 | | -fn concatenate_bf16(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
482 | | - fn filter_bf16( |
483 | | - of: &OrtexTensor, |
484 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> { |
485 | | - match of { |
486 | | - OrtexTensor::bf16(x) => Some(x.view()), |
487 | | - _ => None, |
488 | | - } |
489 | | - } |
490 | | - |
491 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&half::bf16>, Dim<ndarray::IxDynImpl>>> = |
492 | | - tensors.iter().filter_map(|val| filter_bf16(val)).collect(); |
493 | | - |
494 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
495 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
496 | | - OrtexTensor::bf16(x) |
497 | | -} |
498 | | - |
499 | | -fn concatenate_f32(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
500 | | - fn filter_f32( |
501 | | - of: &OrtexTensor, |
502 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> { |
503 | | - match of { |
504 | | - OrtexTensor::f32(x) => Some(x.view()), |
505 | | - _ => None, |
506 | | - } |
507 | | - } |
508 | | - |
509 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f32>, Dim<ndarray::IxDynImpl>>> = |
510 | | - tensors.iter().filter_map(|val| filter_f32(val)).collect(); |
511 | | - |
512 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
513 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
514 | | - OrtexTensor::f32(x) |
515 | | -} |
516 | | - |
517 | | -fn concatenate_f64(tensors: Vec<ResourceArc<OrtexTensor>>, axis: usize) -> OrtexTensor { |
518 | | - fn filter_f64( |
519 | | - of: &OrtexTensor, |
520 | | - ) -> Option<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> { |
521 | | - match of { |
522 | | - OrtexTensor::f64(x) => Some(x.view()), |
523 | | - _ => None, |
524 | | - } |
525 | | - } |
526 | | - |
527 | | - let tensors: Vec<ArrayBase<ndarray::ViewRepr<&f64>, Dim<ndarray::IxDynImpl>>> = |
528 | | - tensors.iter().filter_map(|val| filter_f64(val)).collect(); |
529 | | - |
530 | | - let x = ndarray::concatenate(Axis(axis), &tensors).unwrap(); |
531 | | - let x = Array::from_shape_vec(x.raw_dim(), x.iter().cloned().collect()).unwrap(); |
532 | | - OrtexTensor::f64(x) |
533 | | -} |
0 commit comments