|
| 1 | +use openvino_sys::*; |
| 2 | + |
| 3 | +use std::convert::TryFrom; |
| 4 | +use std::error::Error; |
| 5 | +use std::fmt; |
| 6 | + |
1 | 7 | /// `ElementType` represents the type of elements that a tensor can hold. See [`ElementType`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__base__c__api.html#_CPPv417ov_element_type_e).
|
2 |
| -#[derive(Eq, PartialEq, Copy, Clone, Debug)] |
| 8 | +#[derive(Clone, Copy, Debug, PartialEq, Eq)] |
3 | 9 | #[repr(u32)]
|
4 | 10 | pub enum ElementType {
|
5 | 11 | /// An undefined element type.
|
6 |
| - Undefined = openvino_sys::ov_element_type_e_UNDEFINED, |
| 12 | + Undefined = ov_element_type_e_UNDEFINED, |
7 | 13 | /// A dynamic element type.
|
8 |
| - Dynamic = openvino_sys::ov_element_type_e_DYNAMIC, |
| 14 | + Dynamic = ov_element_type_e_DYNAMIC, |
9 | 15 | /// A boolean element type.
|
10 |
| - Boolean = openvino_sys::ov_element_type_e_OV_BOOLEAN, |
| 16 | + Boolean = ov_element_type_e_OV_BOOLEAN, |
11 | 17 | /// A Bf16 element type.
|
12 |
| - Bf16 = openvino_sys::ov_element_type_e_BF16, |
| 18 | + Bf16 = ov_element_type_e_BF16, |
13 | 19 | /// A F16 element type.
|
14 |
| - F16 = openvino_sys::ov_element_type_e_F16, |
| 20 | + F16 = ov_element_type_e_F16, |
15 | 21 | /// A F32 element type.
|
16 |
| - F32 = openvino_sys::ov_element_type_e_F32, |
| 22 | + F32 = ov_element_type_e_F32, |
17 | 23 | /// A F64 element type.
|
18 |
| - F64 = openvino_sys::ov_element_type_e_F64, |
| 24 | + F64 = ov_element_type_e_F64, |
19 | 25 | /// A 4-bit integer element type.
|
20 |
| - I4 = openvino_sys::ov_element_type_e_I4, |
| 26 | + I4 = ov_element_type_e_I4, |
21 | 27 | /// An 8-bit integer element type.
|
22 |
| - I8 = openvino_sys::ov_element_type_e_I8, |
| 28 | + I8 = ov_element_type_e_I8, |
23 | 29 | /// A 16-bit integer element type.
|
24 |
| - I16 = openvino_sys::ov_element_type_e_I16, |
| 30 | + I16 = ov_element_type_e_I16, |
25 | 31 | /// A 32-bit integer element type.
|
26 |
| - I32 = openvino_sys::ov_element_type_e_I32, |
| 32 | + I32 = ov_element_type_e_I32, |
27 | 33 | /// A 64-bit integer element type.
|
28 |
| - I64 = openvino_sys::ov_element_type_e_I64, |
| 34 | + I64 = ov_element_type_e_I64, |
29 | 35 | /// An 1-bit unsigned integer element type.
|
30 |
| - U1 = openvino_sys::ov_element_type_e_U1, |
| 36 | + U1 = ov_element_type_e_U1, |
31 | 37 | /// An 4-bit unsigned integer element type.
|
32 |
| - U4 = openvino_sys::ov_element_type_e_U4, |
| 38 | + U4 = ov_element_type_e_U4, |
33 | 39 | /// An 8-bit unsigned integer element type.
|
34 |
| - U8 = openvino_sys::ov_element_type_e_U8, |
| 40 | + U8 = ov_element_type_e_U8, |
35 | 41 | /// A 16-bit unsigned integer element type.
|
36 |
| - U16 = openvino_sys::ov_element_type_e_U16, |
| 42 | + U16 = ov_element_type_e_U16, |
37 | 43 | /// A 32-bit unsigned integer element type.
|
38 |
| - U32 = openvino_sys::ov_element_type_e_U32, |
| 44 | + U32 = ov_element_type_e_U32, |
39 | 45 | /// A 64-bit unsigned integer element type.
|
40 |
| - U64 = openvino_sys::ov_element_type_e_U64, |
| 46 | + U64 = ov_element_type_e_U64, |
41 | 47 | /// NF4 element type.
|
42 |
| - NF4 = openvino_sys::ov_element_type_e_NF4, |
| 48 | + NF4 = ov_element_type_e_NF4, |
43 | 49 | /// F8E4M3 element type.
|
44 |
| - F8E4M3 = openvino_sys::ov_element_type_e_F8E4M3, |
| 50 | + F8E4M3 = ov_element_type_e_F8E4M3, |
45 | 51 | /// F8E5M3 element type.
|
46 |
| - F8E5M3 = openvino_sys::ov_element_type_e_F8E5M3, |
| 52 | + F8E5M3 = ov_element_type_e_F8E5M3, |
| 53 | +} |
| 54 | + |
| 55 | +/// Error returned when attempting to create an [`ElementType`] from an illegal `u32` value. |
| 56 | +#[derive(Debug)] |
| 57 | +pub struct IllegalValueError(u32); |
| 58 | + |
| 59 | +impl Error for IllegalValueError {} |
| 60 | + |
| 61 | +impl fmt::Display for IllegalValueError { |
| 62 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 63 | + write!(f, "illegal value: {}", self.0) |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +impl TryFrom<u32> for ElementType { |
| 68 | + type Error = IllegalValueError; |
| 69 | + |
| 70 | + fn try_from(value: u32) -> Result<Self, Self::Error> { |
| 71 | + #[allow(non_upper_case_globals)] |
| 72 | + match value { |
| 73 | + ov_element_type_e_UNDEFINED => Ok(Self::Undefined), |
| 74 | + ov_element_type_e_DYNAMIC => Ok(Self::Dynamic), |
| 75 | + ov_element_type_e_OV_BOOLEAN => Ok(Self::Boolean), |
| 76 | + ov_element_type_e_BF16 => Ok(Self::Bf16), |
| 77 | + ov_element_type_e_F16 => Ok(Self::F16), |
| 78 | + ov_element_type_e_F32 => Ok(Self::F32), |
| 79 | + ov_element_type_e_F64 => Ok(Self::F64), |
| 80 | + ov_element_type_e_I4 => Ok(Self::I4), |
| 81 | + ov_element_type_e_I8 => Ok(Self::I8), |
| 82 | + ov_element_type_e_I16 => Ok(Self::I16), |
| 83 | + ov_element_type_e_I32 => Ok(Self::I32), |
| 84 | + ov_element_type_e_I64 => Ok(Self::I64), |
| 85 | + ov_element_type_e_U1 => Ok(Self::U1), |
| 86 | + ov_element_type_e_U4 => Ok(Self::U4), |
| 87 | + ov_element_type_e_U8 => Ok(Self::U8), |
| 88 | + ov_element_type_e_U16 => Ok(Self::U16), |
| 89 | + ov_element_type_e_U32 => Ok(Self::U32), |
| 90 | + ov_element_type_e_U64 => Ok(Self::U64), |
| 91 | + ov_element_type_e_NF4 => Ok(Self::NF4), |
| 92 | + ov_element_type_e_F8E4M3 => Ok(Self::F8E4M3), |
| 93 | + ov_element_type_e_F8E5M3 => Ok(Self::F8E5M3), |
| 94 | + _ => Err(IllegalValueError(value)), |
| 95 | + } |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +impl From<ElementType> for u32 { |
| 100 | + fn from(value: ElementType) -> Self { |
| 101 | + value as Self |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +impl fmt::Display for ElementType { |
| 106 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 107 | + match self { |
| 108 | + Self::Undefined => write!(f, "Undefined"), |
| 109 | + Self::Dynamic => write!(f, "Dynamic"), |
| 110 | + Self::Boolean => write!(f, "Boolean"), |
| 111 | + Self::Bf16 => write!(f, "Bf16"), |
| 112 | + Self::F16 => write!(f, "F16"), |
| 113 | + Self::F32 => write!(f, "F32"), |
| 114 | + Self::F64 => write!(f, "F64"), |
| 115 | + Self::I4 => write!(f, "I4"), |
| 116 | + Self::I8 => write!(f, "I8"), |
| 117 | + Self::I16 => write!(f, "I16"), |
| 118 | + Self::I32 => write!(f, "I32"), |
| 119 | + Self::I64 => write!(f, "I64"), |
| 120 | + Self::U1 => write!(f, "U1"), |
| 121 | + Self::U4 => write!(f, "U4"), |
| 122 | + Self::U8 => write!(f, "U8"), |
| 123 | + Self::U16 => write!(f, "U16"), |
| 124 | + Self::U32 => write!(f, "U32"), |
| 125 | + Self::U64 => write!(f, "U64"), |
| 126 | + Self::NF4 => write!(f, "NF4"), |
| 127 | + Self::F8E4M3 => write!(f, "F8E4M3"), |
| 128 | + Self::F8E5M3 => write!(f, "F8E5M3"), |
| 129 | + } |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +#[cfg(test)] |
| 134 | +mod tests { |
| 135 | + use super::*; |
| 136 | + use std::convert::TryInto as _; |
| 137 | + |
| 138 | + #[test] |
| 139 | + fn try_from_u32() { |
| 140 | + assert_eq!(ElementType::Undefined, 0u32.try_into().unwrap()); |
| 141 | + let last: u32 = ElementType::F8E5M3.into(); |
| 142 | + let result: Result<ElementType, _> = (last + 1).try_into(); |
| 143 | + assert!(result.is_err()); |
| 144 | + } |
47 | 145 | }
|
0 commit comments