|
1 | 1 | //! Conversions for packing/unpacking `OrtexTensor`s into different types |
2 | 2 | use core::convert::TryFrom; |
3 | | -use half::{bf16, f16}; |
4 | 3 | use ndarray::prelude::*; |
5 | 4 | use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; |
6 | 5 | use ort::{DynValue, Error, Value}; |
@@ -187,24 +186,19 @@ impl OrtexTensor { |
187 | 186 |
|
188 | 187 | pub fn to_bool(self) -> OrtexTensor { |
189 | 188 | match self { |
190 | | - OrtexTensor::s8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
191 | | - OrtexTensor::s16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
192 | | - OrtexTensor::s32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
193 | | - OrtexTensor::s64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
194 | | - OrtexTensor::u8(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
195 | | - OrtexTensor::u16(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
196 | | - OrtexTensor::u32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
197 | | - OrtexTensor::u64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0)), |
198 | | - OrtexTensor::f16(y) => { |
199 | | - OrtexTensor::bool(y.to_owned().mapv(|x| x != f16::ZERO || x != f16::NEG_ZERO)) |
| 189 | + OrtexTensor::u8(y) => { |
| 190 | + let bool_tensor = y.to_owned().mapv(|x| match x { |
| 191 | + 0 => false, |
| 192 | + 1 => true, |
| 193 | + _ => { |
| 194 | + panic!( |
| 195 | + "Tried to convert a u8 tensor to bool, but not every element is 0 or 1" |
| 196 | + ) |
| 197 | + } |
| 198 | + }); |
| 199 | + OrtexTensor::bool(bool_tensor) |
200 | 200 | } |
201 | | - OrtexTensor::bf16(y) => OrtexTensor::bool( |
202 | | - y.to_owned() |
203 | | - .mapv(|x| x != bf16::ZERO || x != bf16::NEG_ZERO), |
204 | | - ), |
205 | | - OrtexTensor::f32(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)), |
206 | | - OrtexTensor::f64(y) => OrtexTensor::bool(y.to_owned().mapv(|x| x != 0.)), |
207 | | - _ => panic!("Can't convert this type to bool"), |
| 201 | + t => panic!("Can't convert this type {:?} to bool", t.dtype()), |
208 | 202 | } |
209 | 203 | } |
210 | 204 | } |
@@ -285,10 +279,10 @@ impl TryFrom<&Value> for OrtexTensor { |
285 | 279 | ort::TensorElementType::String => { |
286 | 280 | todo!("Can't return string tensors") |
287 | 281 | } |
288 | | - // map the output into integer space |
| 282 | + // map the output into u8 space |
289 | 283 | ort::TensorElementType::Bool => { |
290 | 284 | let nd_array = e.try_extract_tensor::<bool>()?.into_owned(); |
291 | | - OrtexTensor::s8(nd_array.mapv(|x| x as i8)) |
| 285 | + OrtexTensor::u8(nd_array.mapv(|x| x as u8)) |
292 | 286 | } |
293 | 287 | }; |
294 | 288 |
|
|
0 commit comments