Skip to content

Commit 863a56b

Browse files
committed
aligning to nx bool handling
1 parent 34c4c98 commit 863a56b

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

native/ortex/src/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ pub fn run(
116116
)[..],
117117
);
118118

119-
// NOTE: try_into impl here will implicitly map bool outputs to signed i8 outputs
119+
// NOTE: try_into impl here will implicitly map bool outputs to u8 outputs
120120
let ortextensor: OrtexTensor = val.try_into()?;
121121
let shape = ortextensor.shape();
122122
let (dtype, bits) = ortextensor.dtype();

native/ortex/src/tensor.rs

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use core::convert::TryFrom;
3-
use half::{bf16, f16};
43
use ndarray::prelude::*;
54
use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr};
65
use ort::{DynValue, Error, Value};
@@ -187,24 +186,19 @@ impl OrtexTensor {
187186

188187
pub fn to_bool(self) -> OrtexTensor {
189188
match self {
190-
OrtexTensor::s8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
191-
OrtexTensor::s16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
192-
OrtexTensor::s32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
193-
OrtexTensor::s64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
194-
OrtexTensor::u8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
195-
OrtexTensor::u16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
196-
OrtexTensor::u32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
197-
OrtexTensor::u64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)),
198-
OrtexTensor::f16(y) => {
199-
OrtexTensor::bool(y.to_owned().mapv(|x| x != f16::ZERO || x != f16::NEG_ZERO))
189+
OrtexTensor::u8(y) => {
190+
let bool_tensor = y.to_owned().mapv(|x| match x {
191+
0 => false,
192+
1 => true,
193+
_ => {
194+
panic!(
195+
"Tried to convert a u8 tensor to bool, but not every element is 0 or 1"
196+
)
197+
}
198+
});
199+
OrtexTensor::bool(bool_tensor)
200200
}
201-
OrtexTensor::bf16(y) => OrtexTensor::bool(
202-
y.to_owned()
203-
.mapv(|x| x != bf16::ZERO || x != bf16::NEG_ZERO),
204-
),
205-
OrtexTensor::f32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)),
206-
OrtexTensor::f64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)),
207-
_ => panic!("Can't convert this type to bool"),
201+
t => panic!("Can't convert this type {:?} to bool", t.dtype()),
208202
}
209203
}
210204
}
@@ -285,10 +279,10 @@ impl TryFrom<&Value> for OrtexTensor {
285279
ort::TensorElementType::String => {
286280
todo!("Can't return string tensors")
287281
}
288-
// map the output into integer space
282+
// map the output into u8 space
289283
ort::TensorElementType::Bool => {
290284
let nd_array = e.try_extract_tensor::<bool>()?.into_owned();
291-
OrtexTensor::s8(nd_array.mapv(|x| x as i8))
285+
OrtexTensor::u8(nd_array.mapv(|x| x as u8))
292286
}
293287
};
294288

0 commit comments

Comments
 (0)