Skip to content

Commit 4fa5755

Browse files
author
Greg Szumel
committed
formatting tweaks
1 parent 6bbdb33 commit 4fa5755

File tree

2 files changed

+21
-43
lines changed

2 files changed

+21
-43
lines changed

native/ortex/src/tensor.rs

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3-
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, ViewRepr, IxDynImpl};
3+
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr};
44
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
55
use ort::OrtError;
66
use rustler::resource::ResourceArc;
@@ -28,7 +28,6 @@ pub enum OrtexTensor {
2828
}
2929

3030
impl From<&OrtexTensor> for InputTensor {
31-
3231
fn from(tensor: &OrtexTensor) -> Self {
3332
match tensor {
3433
OrtexTensor::s8(y) => InputTensor::from_array(y.clone().into()),
@@ -289,38 +288,35 @@ impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
289288

290289
macro_rules! concatenate {
291290
// `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant
292-
($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) =>{
293-
{
294-
type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>;
295-
fn filter(tensor: &OrtexTensor) -> Option<ArrayType> {
296-
match tensor {
297-
OrtexTensor::$ort_tensor_kind(x) => Some(x.view()),
298-
_ => None,
299-
}
291+
($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{
292+
type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>;
293+
fn filter(tensor: &OrtexTensor) -> Option<ArrayType> {
294+
match tensor {
295+
OrtexTensor::$ort_tensor_kind(x) => Some(x.view()),
296+
_ => None,
300297
}
301-
// hack way to type coalesce. Filters out any ndarray's that don't
302-
// have the desired type
303-
let tensors: Vec<ArrayType> =
304-
$tensors.iter().filter_map(|tensor| { filter(tensor) }).collect();
305-
306-
let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap();
307-
// data is not contiguous after the concatenation above. To decode
308-
// properly, need to create a new contiguous vector
309-
let tensors = Array::from_shape_vec(
310-
tensors.raw_dim(),
311-
tensors.iter().cloned().collect())
312-
.unwrap();
313-
OrtexTensor::$ort_tensor_kind(tensors)
314298
}
315-
}
299+
// hack way to type coalesce. Filters out any ndarray's that don't
300+
// have the desired type
301+
let tensors: Vec<ArrayType> = $tensors
302+
.iter()
303+
.filter_map(|tensor| filter(tensor))
304+
.collect();
305+
306+
let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap();
307+
// data is not contiguous after the concatenation above. To decode
308+
// properly, need to create a new contiguous vector
309+
let tensors =
310+
Array::from_shape_vec(tensors.raw_dim(), tensors.iter().cloned().collect()).unwrap();
311+
OrtexTensor::$ort_tensor_kind(tensors)
312+
}};
316313
}
317314

318315
pub fn concatenate(
319316
tensors: Vec<ResourceArc<OrtexTensor>>,
320317
dtype: (&str, usize),
321318
axis: usize,
322319
) -> OrtexTensor {
323-
324320
match dtype {
325321
("s", 8) => concatenate!(tensors, axis, i8, s8),
326322
("s", 16) => concatenate!(tensors, axis, i16, s16),

test/shape/concat_test.exs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,4 @@ defmodule Ortex.TestConcat do
129129
_err = Nx.concatenate([t1, t2])
130130
end
131131
end
132-
133-
# Ignoring these tests, as Nx.Shape takes care of determining if the shape is valid
134-
135-
# test "Concat fails to concat vectors with invalid default axis" do
136-
# assert_raise ArgumentError, "expected all shapes to match {*, 5, 7}, got unmatching shape: {2, 4, 7}", fn() ->
137-
# t1 = Nx.iota({3, 5, 7}) |> Nx.backend_transfer(Ortex.Backend)
138-
# t2 = Nx.iota({2, 4, 7}) |> Nx.backend_transfer(Ortex.Backend)
139-
# _err = Nx.concatenate([t1, t2])
140-
# end
141-
# end
142-
143-
# test "Concat fails to concat vectors with invalid provided axis" do
144-
# assert_raise ArgumentError, "different dims, given axis" do
145-
# t1 = Nx.iota({3, 5, 7}) |> Nx.backend_transfer(Ortex.Backend)
146-
# t2 = Nx.iota({2, 4, 7}) |> Nx.backend_transfer(Ortex.Backend)
147-
# _err = Nx.concatenate([t1, t2], axis: 2)
148-
# end
149-
# end
150132
end

0 commit comments

Comments
 (0)