Skip to content

Commit 0ccb3bf

Browse files
authored
Improve ElementType ergonomics (#102)
* Expose port element type and shape * Expose model's partial shape * Clean up casts * ElementType ergonomics * Return ElementType instead of its u32 reprensentation
1 parent faf8c92 commit 0ccb3bf

File tree

3 files changed

+128
-28
lines changed

3 files changed

+128
-28
lines changed
Lines changed: 120 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,145 @@
1+
use openvino_sys::*;
2+
3+
use std::convert::TryFrom;
4+
use std::error::Error;
5+
use std::fmt;
6+
17
/// `ElementType` represents the type of elements that a tensor can hold. See [`ElementType`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__base__c__api.html#_CPPv417ov_element_type_e).
2-
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
8+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
39
#[repr(u32)]
410
pub enum ElementType {
511
/// An undefined element type.
6-
Undefined = openvino_sys::ov_element_type_e_UNDEFINED,
12+
Undefined = ov_element_type_e_UNDEFINED,
713
/// A dynamic element type.
8-
Dynamic = openvino_sys::ov_element_type_e_DYNAMIC,
14+
Dynamic = ov_element_type_e_DYNAMIC,
915
/// A boolean element type.
10-
Boolean = openvino_sys::ov_element_type_e_OV_BOOLEAN,
16+
Boolean = ov_element_type_e_OV_BOOLEAN,
1117
/// A Bf16 element type.
12-
Bf16 = openvino_sys::ov_element_type_e_BF16,
18+
Bf16 = ov_element_type_e_BF16,
1319
/// A F16 element type.
14-
F16 = openvino_sys::ov_element_type_e_F16,
20+
F16 = ov_element_type_e_F16,
1521
/// A F32 element type.
16-
F32 = openvino_sys::ov_element_type_e_F32,
22+
F32 = ov_element_type_e_F32,
1723
/// A F64 element type.
18-
F64 = openvino_sys::ov_element_type_e_F64,
24+
F64 = ov_element_type_e_F64,
1925
/// A 4-bit integer element type.
20-
I4 = openvino_sys::ov_element_type_e_I4,
26+
I4 = ov_element_type_e_I4,
2127
/// An 8-bit integer element type.
22-
I8 = openvino_sys::ov_element_type_e_I8,
28+
I8 = ov_element_type_e_I8,
2329
/// A 16-bit integer element type.
24-
I16 = openvino_sys::ov_element_type_e_I16,
30+
I16 = ov_element_type_e_I16,
2531
/// A 32-bit integer element type.
26-
I32 = openvino_sys::ov_element_type_e_I32,
32+
I32 = ov_element_type_e_I32,
2733
/// A 64-bit integer element type.
28-
I64 = openvino_sys::ov_element_type_e_I64,
34+
I64 = ov_element_type_e_I64,
2935
/// An 1-bit unsigned integer element type.
30-
U1 = openvino_sys::ov_element_type_e_U1,
36+
U1 = ov_element_type_e_U1,
3137
/// An 4-bit unsigned integer element type.
32-
U4 = openvino_sys::ov_element_type_e_U4,
38+
U4 = ov_element_type_e_U4,
3339
/// An 8-bit unsigned integer element type.
34-
U8 = openvino_sys::ov_element_type_e_U8,
40+
U8 = ov_element_type_e_U8,
3541
/// A 16-bit unsigned integer element type.
36-
U16 = openvino_sys::ov_element_type_e_U16,
42+
U16 = ov_element_type_e_U16,
3743
/// A 32-bit unsigned integer element type.
38-
U32 = openvino_sys::ov_element_type_e_U32,
44+
U32 = ov_element_type_e_U32,
3945
/// A 64-bit unsigned integer element type.
40-
U64 = openvino_sys::ov_element_type_e_U64,
46+
U64 = ov_element_type_e_U64,
4147
/// NF4 element type.
42-
NF4 = openvino_sys::ov_element_type_e_NF4,
48+
NF4 = ov_element_type_e_NF4,
4349
/// F8E4M3 element type.
44-
F8E4M3 = openvino_sys::ov_element_type_e_F8E4M3,
50+
F8E4M3 = ov_element_type_e_F8E4M3,
4551
/// F8E5M3 element type.
46-
F8E5M3 = openvino_sys::ov_element_type_e_F8E5M3,
52+
F8E5M3 = ov_element_type_e_F8E5M3,
53+
}
54+
55+
/// Error returned when attempting to create an [`ElementType`] from an illegal `u32` value.
56+
#[derive(Debug)]
57+
pub struct IllegalValueError(u32);
58+
59+
impl Error for IllegalValueError {}
60+
61+
impl fmt::Display for IllegalValueError {
62+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63+
write!(f, "illegal value: {}", self.0)
64+
}
65+
}
66+
67+
impl TryFrom<u32> for ElementType {
68+
type Error = IllegalValueError;
69+
70+
fn try_from(value: u32) -> Result<Self, Self::Error> {
71+
#[allow(non_upper_case_globals)]
72+
match value {
73+
ov_element_type_e_UNDEFINED => Ok(Self::Undefined),
74+
ov_element_type_e_DYNAMIC => Ok(Self::Dynamic),
75+
ov_element_type_e_OV_BOOLEAN => Ok(Self::Boolean),
76+
ov_element_type_e_BF16 => Ok(Self::Bf16),
77+
ov_element_type_e_F16 => Ok(Self::F16),
78+
ov_element_type_e_F32 => Ok(Self::F32),
79+
ov_element_type_e_F64 => Ok(Self::F64),
80+
ov_element_type_e_I4 => Ok(Self::I4),
81+
ov_element_type_e_I8 => Ok(Self::I8),
82+
ov_element_type_e_I16 => Ok(Self::I16),
83+
ov_element_type_e_I32 => Ok(Self::I32),
84+
ov_element_type_e_I64 => Ok(Self::I64),
85+
ov_element_type_e_U1 => Ok(Self::U1),
86+
ov_element_type_e_U4 => Ok(Self::U4),
87+
ov_element_type_e_U8 => Ok(Self::U8),
88+
ov_element_type_e_U16 => Ok(Self::U16),
89+
ov_element_type_e_U32 => Ok(Self::U32),
90+
ov_element_type_e_U64 => Ok(Self::U64),
91+
ov_element_type_e_NF4 => Ok(Self::NF4),
92+
ov_element_type_e_F8E4M3 => Ok(Self::F8E4M3),
93+
ov_element_type_e_F8E5M3 => Ok(Self::F8E5M3),
94+
_ => Err(IllegalValueError(value)),
95+
}
96+
}
97+
}
98+
99+
impl From<ElementType> for u32 {
100+
fn from(value: ElementType) -> Self {
101+
value as Self
102+
}
103+
}
104+
105+
impl fmt::Display for ElementType {
106+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107+
match self {
108+
Self::Undefined => write!(f, "Undefined"),
109+
Self::Dynamic => write!(f, "Dynamic"),
110+
Self::Boolean => write!(f, "Boolean"),
111+
Self::Bf16 => write!(f, "Bf16"),
112+
Self::F16 => write!(f, "F16"),
113+
Self::F32 => write!(f, "F32"),
114+
Self::F64 => write!(f, "F64"),
115+
Self::I4 => write!(f, "I4"),
116+
Self::I8 => write!(f, "I8"),
117+
Self::I16 => write!(f, "I16"),
118+
Self::I32 => write!(f, "I32"),
119+
Self::I64 => write!(f, "I64"),
120+
Self::U1 => write!(f, "U1"),
121+
Self::U4 => write!(f, "U4"),
122+
Self::U8 => write!(f, "U8"),
123+
Self::U16 => write!(f, "U16"),
124+
Self::U32 => write!(f, "U32"),
125+
Self::U64 => write!(f, "U64"),
126+
Self::NF4 => write!(f, "NF4"),
127+
Self::F8E4M3 => write!(f, "F8E4M3"),
128+
Self::F8E5M3 => write!(f, "F8E5M3"),
129+
}
130+
}
131+
}
132+
133+
#[cfg(test)]
134+
mod tests {
135+
use super::*;
136+
use std::convert::TryInto as _;
137+
138+
#[test]
139+
fn try_from_u32() {
140+
assert_eq!(ElementType::Undefined, 0u32.try_into().unwrap());
141+
let last: u32 = ElementType::F8E5M3.into();
142+
let result: Result<ElementType, _> = (last + 1).try_into();
143+
assert!(result.is_err());
144+
}
47145
}

crates/openvino/src/node.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use openvino_sys::{
44
ov_port_get_element_type, ov_port_get_partial_shape, ov_rank_t, ov_shape_t,
55
};
66

7-
use std::ffi::CStr;
7+
use std::{convert::TryInto as _, ffi::CStr};
88

99
/// See [`Node`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__node__c__api.html).
1010
pub struct Node {
@@ -31,13 +31,13 @@ impl Node {
3131
}
3232

3333
/// Get the data type of elements of the port.
34-
pub fn get_element_type(&self) -> Result<u32> {
34+
pub fn get_element_type(&self) -> Result<ElementType> {
3535
let mut element_type = ElementType::Undefined as u32;
3636
try_unsafe!(ov_port_get_element_type(
3737
self.ptr,
3838
std::ptr::addr_of_mut!(element_type),
3939
))?;
40-
Ok(element_type)
40+
Ok(element_type.try_into().unwrap())
4141
}
4242

4343
/// Get the shape of the port.

crates/openvino/src/tensor.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
//! This module provides functionality related to Tensor objects.
2+
use std::convert::TryInto as _;
3+
24
use crate::element_type::ElementType;
35
use crate::shape::Shape;
46
use crate::{drop_using_function, try_unsafe, util::Result};
@@ -77,13 +79,13 @@ impl Tensor {
7779
}
7880

7981
/// Get the data type of elements of the tensor.
80-
pub fn get_element_type(&self) -> Result<u32> {
82+
pub fn get_element_type(&self) -> Result<ElementType> {
8183
let mut element_type = ElementType::Undefined as u32;
8284
try_unsafe!(ov_tensor_get_element_type(
8385
self.ptr,
8486
std::ptr::addr_of_mut!(element_type),
8587
))?;
86-
Ok(element_type)
88+
Ok(element_type.try_into().unwrap())
8789
}
8890

8991
/// Get the number of elements in the tensor. Product of all dimensions e.g. 1*3*227*227.
@@ -169,7 +171,7 @@ mod tests {
169171
)
170172
.unwrap();
171173
let element_type = tensor.get_element_type().unwrap();
172-
assert_eq!(element_type, ElementType::F32 as u32);
174+
assert_eq!(element_type, ElementType::F32);
173175
}
174176

175177
#[test]

0 commit comments

Comments
 (0)