Skip to content

Commit 8f74cd2

Browse files
authored
Merge pull request #39 from gregszumel/bool_tensors
Boolean input support
2 parents 3eaa913 + 863a56b commit 8f74cd2

File tree

3 files changed

+98
-20
lines changed

3 files changed

+98
-20
lines changed

native/ortex/src/model.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
//! ```
1010
1111
use crate::tensor::OrtexTensor;
12-
use crate::utils::map_opt_level;
13-
use std::convert::{TryFrom, TryInto};
12+
use crate::utils::{is_bool_input, map_opt_level};
13+
use std::convert::TryInto;
14+
use std::iter::zip;
1415

1516
use ort::{Error, ExecutionProviderDispatch, Session};
1617
use rustler::resource::ResourceArc;
@@ -83,27 +84,46 @@ pub fn run(
8384
model: ResourceArc<OrtexModel>,
8485
inputs: Vec<ResourceArc<OrtexTensor>>,
8586
) -> Result<Vec<(ResourceArc<OrtexTensor>, Vec<usize>, Atom, usize)>, Error> {
86-
// TODO: can we handle an error more elegantly than just .unwrap()?
87+
// Grab the session and run a forward pass with it
88+
let session: &ort::Session = &model.session;
8789

8890
let mut ortified_inputs: Vec<ort::SessionInputValue> = Vec::new();
89-
for input in inputs {
90-
let derefed_input: &OrtexTensor = &input;
91-
let v: ort::SessionInputValue = derefed_input.try_into()?;
92-
ortified_inputs.push(v);
93-
}
9491

95-
// Grab the session and run a forward pass with it
96-
let session: &ort::Session = &model.session;
92+
for (elixir_input, onnx_input) in zip(inputs, &session.inputs) {
93+
let derefed_input: &OrtexTensor = &elixir_input;
94+
if is_bool_input(&onnx_input.input_type) {
95+
// this assumes that the boolean input isn't huge -- we're cloning it twice;
96+
// once below, once in the try_into()
97+
let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool();
98+
let v: ort::SessionInputValue = boolified_input.try_into()?;
99+
ortified_inputs.push(v);
100+
} else {
101+
let v: ort::SessionInputValue = derefed_input.try_into()?;
102+
ortified_inputs.push(v);
103+
}
104+
}
97105

98106
// Construct a Vec of ModelOutput enums based on the DynOrtTensor data type
99107
let outputs = session.run(&ortified_inputs[..])?;
100-
outputs
101-
.iter()
102-
.map(|(_name, val)| {
103-
let ortextensor: OrtexTensor = OrtexTensor::try_from(val)?;
104-
let shape = ortextensor.shape();
105-
let (dtype, bits) = ortextensor.dtype();
106-
Ok((ResourceArc::new(ortextensor), shape, dtype, bits))
107-
})
108-
.collect()
108+
let mut collected_outputs = Vec::new();
109+
110+
for output_descriptor in &session.outputs {
111+
let output_name: &str = &output_descriptor.name;
112+
let val = outputs.get(output_name).expect(
113+
&format!(
114+
"Expected {} to be in the outputs, but didn't find it",
115+
output_name
116+
)[..],
117+
);
118+
119+
// NOTE: try_into impl here will implicitly map bool outputs to u8 outputs
120+
let ortextensor: OrtexTensor = val.try_into()?;
121+
let shape = ortextensor.shape();
122+
let (dtype, bits) = ortextensor.dtype();
123+
124+
let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits);
125+
collected_outputs.push(collected_output)
126+
}
127+
128+
Ok(collected_outputs)
109129
}

native/ortex/src/tensor.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ pub enum OrtexTensor {
2626
bf16(Array<half::bf16, IxDyn>),
2727
f32(Array<f32, IxDyn>),
2828
f64(Array<f64, IxDyn>),
29+
// the bool input is for internal use only.
30+
// Any Nx facing ops should panic if called on a bool input
31+
bool(Array<bool, IxDyn>),
2932
}
3033

3134
impl OrtexTensor {
@@ -43,6 +46,7 @@ impl OrtexTensor {
4346
OrtexTensor::bf16(y) => y.shape().to_owned(),
4447
OrtexTensor::f32(y) => y.shape().to_owned(),
4548
OrtexTensor::f64(y) => y.shape().to_owned(),
49+
_ => panic!("Can't convert this type to Nx format"),
4650
}
4751
}
4852

@@ -108,6 +112,7 @@ impl OrtexTensor {
108112
.into_shape_with_order(shape)
109113
.map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?,
110114
)),
115+
_ => panic!("Can't convert this type to Nx format"),
111116
}
112117
}
113118

@@ -125,6 +130,7 @@ impl OrtexTensor {
125130
OrtexTensor::bf16(_) => (ortex_atoms::bf(), 16),
126131
OrtexTensor::f32(_) => (ortex_atoms::f(), 32),
127132
OrtexTensor::f64(_) => (ortex_atoms::f(), 64),
133+
_ => panic!("Can't convert this type to Nx format"),
128134
}
129135
}
130136

@@ -142,6 +148,7 @@ impl OrtexTensor {
142148
OrtexTensor::bf16(y) => get_bytes(y),
143149
OrtexTensor::f32(y) => get_bytes(y),
144150
OrtexTensor::f64(y) => get_bytes(y),
151+
_ => panic!("Can't convert this type to Nx format"),
145152
};
146153
contents
147154
}
@@ -173,6 +180,25 @@ impl OrtexTensor {
173180
OrtexTensor::bf16(y) => OrtexTensor::bf16(slice_array(y, &slice_specs).to_owned()),
174181
OrtexTensor::f32(y) => OrtexTensor::f32(slice_array(y, &slice_specs).to_owned()),
175182
OrtexTensor::f64(y) => OrtexTensor::f64(slice_array(y, &slice_specs).to_owned()),
183+
_ => panic!("Can't convert this type to Nx format"),
184+
}
185+
}
186+
187+
pub fn to_bool(self) -> OrtexTensor {
188+
match self {
189+
OrtexTensor::u8(y) => {
190+
let bool_tensor = y.to_owned().mapv(|x| match x {
191+
0 => false,
192+
1 => true,
193+
_ => {
194+
panic!(
195+
"Tried to convert a u8 tensor to bool, but not every element is 0 or 1"
196+
)
197+
}
198+
});
199+
OrtexTensor::bool(bool_tensor)
200+
}
201+
t => panic!("Can't convert this type {:?} to bool", t.dtype()),
176202
}
177203
}
178204
}
@@ -253,8 +279,10 @@ impl TryFrom<&Value> for OrtexTensor {
253279
ort::TensorElementType::String => {
254280
todo!("Can't return string tensors")
255281
}
282+
// map the output into u8 space
256283
ort::TensorElementType::Bool => {
257-
todo!("Can't return bool tensors")
284+
let nd_array = e.try_extract_tensor::<bool>()?.into_owned();
285+
OrtexTensor::u8(nd_array.mapv(|x| x as u8))
258286
}
259287
};
260288

@@ -278,11 +306,32 @@ impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> {
278306
OrtexTensor::u16(arr) => arr.clone().try_into()?,
279307
OrtexTensor::u32(arr) => arr.clone().try_into()?,
280308
OrtexTensor::u64(arr) => arr.clone().try_into()?,
309+
OrtexTensor::bool(arr) => arr.clone().try_into()?,
281310
};
282311
Ok(r.into())
283312
}
284313
}
285314

315+
impl Clone for OrtexTensor {
316+
fn clone(&self) -> Self {
317+
match self {
318+
OrtexTensor::s8(t) => OrtexTensor::s8(t.clone()),
319+
OrtexTensor::s16(t) => OrtexTensor::s16(t.clone()),
320+
OrtexTensor::s32(t) => OrtexTensor::s32(t.clone()),
321+
OrtexTensor::s64(t) => OrtexTensor::s64(t.clone()),
322+
OrtexTensor::bf16(t) => OrtexTensor::bf16(t.clone()),
323+
OrtexTensor::f16(t) => OrtexTensor::f16(t.clone()),
324+
OrtexTensor::f32(t) => OrtexTensor::f32(t.clone()),
325+
OrtexTensor::f64(t) => OrtexTensor::f64(t.clone()),
326+
OrtexTensor::u8(t) => OrtexTensor::u8(t.clone()),
327+
OrtexTensor::u16(t) => OrtexTensor::u16(t.clone()),
328+
OrtexTensor::u32(t) => OrtexTensor::u32(t.clone()),
329+
OrtexTensor::u64(t) => OrtexTensor::u64(t.clone()),
330+
OrtexTensor::bool(t) => OrtexTensor::bool(t.clone()),
331+
}
332+
}
333+
}
334+
286335
// Currently only supports concatenating tenors of the same type.
287336
//
288337
// This is a similar structure to the above match clauses, except each function

native/ortex/src/utils.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,12 @@ pub fn map_opt_level(opt: i32) -> GraphOptimizationLevel {
116116
_ => GraphOptimizationLevel::Disable,
117117
}
118118
}
119+
120+
pub fn is_bool_input(inp: &ort::ValueType) -> bool {
121+
match inp {
122+
ort::ValueType::Tensor { ty, .. } => ty == &ort::TensorElementType::Bool,
123+
ort::ValueType::Map { value, .. } => value == &ort::TensorElementType::Bool,
124+
ort::ValueType::Sequence(boxed_input) => is_bool_input(boxed_input),
125+
ort::ValueType::Optional(boxed_input) => is_bool_input(boxed_input),
126+
}
127+
}

0 commit comments

Comments
 (0)