|
1 | 1 | //! Conversions for packing/unpacking `OrtexTensor`s into different types |
2 | 2 | use ndarray::prelude::*; |
3 | | -use ndarray::{ArrayBase, ArrayView, Data, IxDyn, ViewRepr, IxDynImpl}; |
| 3 | +use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; |
4 | 4 | use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType}; |
5 | 5 | use ort::OrtError; |
6 | 6 | use rustler::resource::ResourceArc; |
@@ -28,7 +28,6 @@ pub enum OrtexTensor { |
28 | 28 | } |
29 | 29 |
|
30 | 30 | impl From<&OrtexTensor> for InputTensor { |
31 | | - |
32 | 31 | fn from(tensor: &OrtexTensor) -> Self { |
33 | 32 | match tensor { |
34 | 33 | OrtexTensor::s8(y) => InputTensor::from_array(y.clone().into()), |
@@ -289,38 +288,35 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor { |
289 | 288 |
|
290 | 289 | macro_rules! concatenate { |
291 | 290 | // `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant |
292 | | - ($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) =>{ |
293 | | - { |
294 | | - type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>; |
295 | | - fn filter(tensor: &OrtexTensor) -> Option<ArrayType> { |
296 | | - match tensor { |
297 | | - OrtexTensor::$ort_tensor_kind(x) => Some(x.view()), |
298 | | - _ => None, |
299 | | - } |
| 291 | + ($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{ |
| 292 | + type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>; |
| 293 | + fn filter(tensor: &OrtexTensor) -> Option<ArrayType> { |
| 294 | + match tensor { |
| 295 | + OrtexTensor::$ort_tensor_kind(x) => Some(x.view()), |
| 296 | + _ => None, |
300 | 297 | } |
301 | | - // hack way to type coalesce. Filters out any ndarray's that don't |
302 | | - // have the desired type |
303 | | - let tensors: Vec<ArrayType> = |
304 | | - $tensors.iter().filter_map(|tensor| { filter(tensor) }).collect(); |
305 | | - |
306 | | - let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap(); |
307 | | - // data is not contiguous after the concatenation above. To decode |
308 | | - // properly, need to create a new contiguous vector |
309 | | - let tensors = Array::from_shape_vec( |
310 | | - tensors.raw_dim(), |
311 | | - tensors.iter().cloned().collect()) |
312 | | - .unwrap(); |
313 | | - OrtexTensor::$ort_tensor_kind(tensors) |
314 | 298 | } |
315 | | - } |
| 299 | + // hack way to type coalesce. Filters out any ndarray's that don't |
| 300 | + // have the desired type |
| 301 | + let tensors: Vec<ArrayType> = $tensors |
| 302 | + .iter() |
| 303 | + .filter_map(|tensor| filter(tensor)) |
| 304 | + .collect(); |
| 305 | + |
| 306 | + let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap(); |
| 307 | + // data is not contiguous after the concatenation above. To decode |
| 308 | + // properly, need to create a new contiguous vector |
| 309 | + let tensors = |
| 310 | + Array::from_shape_vec(tensors.raw_dim(), tensors.iter().cloned().collect()).unwrap(); |
| 311 | + OrtexTensor::$ort_tensor_kind(tensors) |
| 312 | + }}; |
316 | 313 | } |
317 | 314 |
|
318 | 315 | pub fn concatenate( |
319 | 316 | tensors: Vec<ResourceArc<OrtexTensor>>, |
320 | 317 | dtype: (&str, usize), |
321 | 318 | axis: usize, |
322 | 319 | ) -> OrtexTensor { |
323 | | - |
324 | 320 | match dtype { |
325 | 321 | ("s", 8) => concatenate!(tensors, axis, i8, s8), |
326 | 322 | ("s", 16) => concatenate!(tensors, axis, i16, s16), |
|
0 commit comments