Skip to content

Commit 78b0865

Browse files
committed
adding bool support
1 parent f78d11f commit 78b0865

File tree

3 files changed

+104
-20
lines changed

3 files changed

+104
-20
lines changed

native/ortex/src/model.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
//! ```
1010
1111
use crate::tensor::OrtexTensor;
12-
use crate::utils::map_opt_level;
13-
use std::convert::{TryFrom, TryInto};
12+
use crate::utils::{is_bool_input, map_opt_level};
13+
use std::convert::TryInto;
14+
use std::iter::zip;
1415

1516
use ort::{Error, ExecutionProviderDispatch, Session};
1617
use rustler::resource::ResourceArc;
@@ -83,27 +84,46 @@ pub fn run(
8384
model: ResourceArc<OrtexModel>,
8485
inputs: Vec<ResourceArc<OrtexTensor>>,
8586
) -> Result<Vec<(ResourceArc<OrtexTensor>, Vec<usize>, Atom, usize)>, Error> {
86-
// TODO: can we handle an error more elegantly than just .unwrap()?
87+
// Grab the session and run a forward pass with it
88+
let session: &ort::Session = &model.session;
8789

8890
let mut ortified_inputs: Vec<ort::SessionInputValue> = Vec::new();
89-
for input in inputs {
90-
let derefed_input: &OrtexTensor = &input;
91-
let v: ort::SessionInputValue = derefed_input.try_into()?;
92-
ortified_inputs.push(v);
93-
}
9491

95-
// Grab the session and run a forward pass with it
96-
let session: &ort::Session = &model.session;
92+
for (elixir_input, onnx_input) in zip(inputs, &session.inputs) {
93+
let derefed_input: &OrtexTensor = &elixir_input;
94+
if is_bool_input(&onnx_input.input_type) {
95+
// this assumes that the boolean input isn't huge -- we're cloning it twice;
96+
// once below, once in the try_into()
97+
let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool();
98+
let v: ort::SessionInputValue = boolified_input.try_into()?;
99+
ortified_inputs.push(v);
100+
} else {
101+
let v: ort::SessionInputValue = derefed_input.try_into()?;
102+
ortified_inputs.push(v);
103+
}
104+
}
97105

98106
// Construct a Vec of ModelOutput enums based on the DynOrtTensor data type
99107
let outputs = session.run(&ortified_inputs[..])?;
100-
outputs
101-
.iter()
102-
.map(|(_name, val)| {
103-
let ortextensor: OrtexTensor = OrtexTensor::try_from(val)?;
104-
let shape = ortextensor.shape();
105-
let (dtype, bits) = ortextensor.dtype();
106-
Ok((ResourceArc::new(ortextensor), shape, dtype, bits))
107-
})
108-
.collect()
108+
let mut collected_outputs = Vec::new();
109+
110+
for output_descriptor in &session.outputs {
111+
let output_name: &str = &output_descriptor.name;
112+
let val = outputs.get(output_name).expect(
113+
&format!(
114+
"Expected {} to be in the outputs, but didn't find it",
115+
output_name
116+
)[..],
117+
);
118+
119+
// NOTE: try_into impl here will implicitly map bool outputs to signed i8 outputs
120+
let ortextensor: OrtexTensor = val.try_into()?;
121+
let shape = ortextensor.shape();
122+
let (dtype, bits) = ortextensor.dtype();
123+
124+
let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits);
125+
collected_outputs.push(collected_output)
126+
}
127+
128+
Ok(collected_outputs)
109129
}

native/ortex/src/tensor.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use core::convert::TryFrom;
3+
use half::{bf16, f16};
34
use ndarray::prelude::*;
45
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr};
56
use ort::{DynValue, Error, Value};
@@ -26,6 +27,9 @@ pub enum OrtexTensor {
2627
bf16(Array<half::bf16, IxDyn>),
2728
f32(Array<f32, IxDyn>),
2829
f64(Array<f64, IxDyn>),
30+
// the bool input is for internal use only.
31+
// Any Nx facing ops should panic if called on a bool input
32+
bool(Array<bool, IxDyn>),
2933
}
3034

3135
impl OrtexTensor {
@@ -43,6 +47,7 @@ impl OrtexTensor {
4347
OrtexTensor::bf16(y) => y.shape().to_owned(),
4448
OrtexTensor::f32(y) => y.shape().to_owned(),
4549
OrtexTensor::f64(y) => y.shape().to_owned(),
50+
_ => panic!("Can't convert this type to Nx format"),
4651
}
4752
}
4853

@@ -108,6 +113,7 @@ impl OrtexTensor {
108113
.into_shape_with_order(shape)
109114
.map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?,
110115
)),
116+
_ => panic!("Can't convert this type to Nx format"),
111117
}
112118
}
113119

@@ -125,6 +131,7 @@ impl OrtexTensor {
125131
OrtexTensor::bf16(_) => (ortex_atoms::bf(), 16),
126132
OrtexTensor::f32(_) => (ortex_atoms::f(), 32),
127133
OrtexTensor::f64(_) => (ortex_atoms::f(), 64),
134+
_ => panic!("Can't convert this type to Nx format"),
128135
}
129136
}
130137

@@ -142,6 +149,7 @@ impl OrtexTensor {
142149
OrtexTensor::bf16(y) => get_bytes(y),
143150
OrtexTensor::f32(y) => get_bytes(y),
144151
OrtexTensor::f64(y) => get_bytes(y),
152+
_ => panic!("Can't convert this type to Nx format"),
145153
};
146154
contents
147155
}
@@ -173,6 +181,30 @@ impl OrtexTensor {
173181
OrtexTensor::bf16(y) => OrtexTensor::bf16(slice_array(y, &slice_specs).to_owned()),
174182
OrtexTensor::f32(y) => OrtexTensor::f32(slice_array(y, &slice_specs).to_owned()),
175183
OrtexTensor::f64(y) => OrtexTensor::f64(slice_array(y, &slice_specs).to_owned()),
184+
_ => panic!("Can't convert this type to Nx format"),
185+
}
186+
}
187+
188+
pub fn to_bool(self) -> OrtexTensor {
189+
match self {
190+
OrtexTensor::s8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
191+
OrtexTensor::s16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
192+
OrtexTensor::s32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
193+
OrtexTensor::s64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
194+
OrtexTensor::u8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
195+
OrtexTensor::u16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
196+
OrtexTensor::u32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
197+
OrtexTensor::u64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
198+
OrtexTensor::f16(y) => {
199+
OrtexTensor::bool(y.to_owned().mapv(|x| x != f16::ZERO || x != f16::NEG_ZERO))
200+
}
201+
OrtexTensor::bf16(y) => OrtexTensor::bool(
202+
y.to_owned()
203+
.mapv(|x| x != bf16::ZERO || x != bf16::NEG_ZERO),
204+
),
205+
OrtexTensor::f32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)),
206+
OrtexTensor::f64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)),
207+
_ => panic!("Can't convert this type to bool"),
176208
}
177209
}
178210
}
@@ -253,8 +285,10 @@ impl TryFrom<&Value> for OrtexTensor {
253285
ort::TensorElementType::String => {
254286
todo!("Can't return string tensors")
255287
}
288+
// map the output into integer space
256289
ort::TensorElementType::Bool => {
257-
todo!("Can't return bool tensors")
290+
let nd_array = e.try_extract_tensor::<bool>()?.into_owned();
291+
OrtexTensor::s8(nd_array.mapv(|x| x as i8))
258292
}
259293
};
260294

@@ -278,11 +312,32 @@ impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> {
278312
OrtexTensor::u16(arr) => arr.clone().try_into()?,
279313
OrtexTensor::u32(arr) => arr.clone().try_into()?,
280314
OrtexTensor::u64(arr) => arr.clone().try_into()?,
315+
OrtexTensor::bool(arr) => arr.clone().try_into()?,
281316
};
282317
Ok(r.into())
283318
}
284319
}
285320

321+
impl Clone for OrtexTensor {
322+
fn clone(&self) -> Self {
323+
match self {
324+
OrtexTensor::s8(t) => OrtexTensor::s8(t.clone()),
325+
OrtexTensor::s16(t) => OrtexTensor::s16(t.clone()),
326+
OrtexTensor::s32(t) => OrtexTensor::s32(t.clone()),
327+
OrtexTensor::s64(t) => OrtexTensor::s64(t.clone()),
328+
OrtexTensor::bf16(t) => OrtexTensor::bf16(t.clone()),
329+
OrtexTensor::f16(t) => OrtexTensor::f16(t.clone()),
330+
OrtexTensor::f32(t) => OrtexTensor::f32(t.clone()),
331+
OrtexTensor::f64(t) => OrtexTensor::f64(t.clone()),
332+
OrtexTensor::u8(t) => OrtexTensor::u8(t.clone()),
333+
OrtexTensor::u16(t) => OrtexTensor::u16(t.clone()),
334+
OrtexTensor::u32(t) => OrtexTensor::u32(t.clone()),
335+
OrtexTensor::u64(t) => OrtexTensor::u64(t.clone()),
336+
OrtexTensor::bool(t) => OrtexTensor::bool(t.clone()),
337+
}
338+
}
339+
}
340+
286341
// Currently only supports concatenating tenors of the same type.
287342
//
288343
// This is a similar structure to the above match clauses, except each function

native/ortex/src/utils.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,12 @@ pub fn map_opt_level(opt: i32) -> GraphOptimizationLevel {
116116
_ => GraphOptimizationLevel::Disable,
117117
}
118118
}
119+
120+
pub fn is_bool_input(inp: &ort::ValueType) -> bool {
121+
match inp {
122+
ort::ValueType::Tensor { ty, .. } => ty == &ort::TensorElementType::Bool,
123+
ort::ValueType::Map { value, .. } => value == &ort::TensorElementType::Bool,
124+
ort::ValueType::Sequence(boxed_input) => is_bool_input(boxed_input),
125+
ort::ValueType::Optional(boxed_input) => is_bool_input(boxed_input),
126+
}
127+
}

0 commit comments

Comments
 (0)