Skip to content

Commit 8aa9a52

Browse files
rahulchaphalkarBTOdellabrown
authored
Add device type (#108)
* add device type * retain space between imports and code --------- Co-authored-by: Bradley Odell <[email protected]> Co-authored-by: Andrew Brown <[email protected]>
1 parent 3a404ee commit 8aa9a52

File tree

6 files changed

+106
-9
lines changed

6 files changed

+106
-9
lines changed

crates/openvino/src/core.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
use crate::error::LoadingError;
55
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
66
use crate::{model::CompiledModel, Model};
7-
use crate::{SetupError, Tensor};
7+
use crate::{DeviceType, SetupError, Tensor};
88
use openvino_sys::{
99
self, ov_core_compile_model, ov_core_create, ov_core_create_with_config, ov_core_free,
1010
ov_core_read_model, ov_core_read_model_from_memory_buffer, ov_core_t,
1111
};
12+
use std::ffi::CString;
1213

1314
/// See [`Core`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__core__c__api.html).
1415
pub struct Core {
@@ -68,13 +69,14 @@ impl Core {
6869
}
6970

7071
/// Compile a model to `CompiledModel`.
71-
pub fn compile_model(&mut self, model: &Model, device: &str) -> Result<CompiledModel> {
72+
pub fn compile_model(&mut self, model: &Model, device: DeviceType) -> Result<CompiledModel> {
73+
let device: CString = device.into();
7274
let mut compiled_model = std::ptr::null_mut();
7375
let num_property_args = 0;
7476
try_unsafe!(ov_core_compile_model(
7577
self.ptr,
7678
model.as_ptr(),
77-
cstr!(device),
79+
device.as_ptr(),
7880
num_property_args,
7981
std::ptr::addr_of_mut!(compiled_model)
8082
))?;

crates/openvino/src/device_type.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use std::borrow::Cow;
2+
use std::convert::Infallible;
3+
use std::ffi::CString;
4+
use std::fmt::{Display, Formatter};
5+
use std::str::FromStr;
6+
7+
/// `DeviceType` represents accelerator devices.
///
/// Recognized devices are unit variants; any other OpenVINO device string
/// (e.g. `"AUTO"`, `"MULTI:CPU,GPU"`) can be expressed via [`DeviceType::Other`].
//
// NOTE: the derived `Ord`/`PartialOrd` depend on the variant order below;
// do not reorder variants.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum DeviceType<'a> {
    /// [CPU Device](https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/cpu-device.html)
    CPU,
    /// [GPU Device](https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html)
    GPU,
    /// [NPU Device](https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/npu-device.html)
    NPU,
    /// [GNA Device](https://docs.openvino.ai/2023.3/openvino_docs_OV_UG_supported_plugins_GNA.html)
    #[deprecated = "Deprecated since OpenVINO 2024.0; use NPU device instead"]
    GNA,
    /// Arbitrary device; the string is passed to OpenVINO verbatim.
    Other(Cow<'a, str>),
}
22+
23+
impl DeviceType<'_> {
24+
/// Creates a device type with owned string data.
25+
pub fn to_owned(&self) -> DeviceType<'static> {
26+
match self {
27+
DeviceType::CPU => DeviceType::CPU,
28+
DeviceType::GPU => DeviceType::GPU,
29+
DeviceType::NPU => DeviceType::NPU,
30+
#[allow(deprecated)]
31+
DeviceType::GNA => DeviceType::GNA,
32+
DeviceType::Other(s) => DeviceType::Other(Cow::Owned(s.clone().into_owned())),
33+
}
34+
}
35+
}
36+
37+
impl AsRef<str> for DeviceType<'_> {
38+
fn as_ref(&self) -> &str {
39+
match self {
40+
DeviceType::CPU => "CPU",
41+
DeviceType::GPU => "GPU",
42+
DeviceType::NPU => "NPU",
43+
#[allow(deprecated)]
44+
DeviceType::GNA => "GNA",
45+
DeviceType::Other(s) => s,
46+
}
47+
}
48+
}
49+
50+
impl<'a> From<&'a DeviceType<'a>> for &'a str {
    // Thin adapter so `&DeviceType` converts where a `&str` is wanted;
    // delegates to the canonical `AsRef<str>` mapping.
    fn from(device: &'a DeviceType) -> Self {
        device.as_ref()
    }
}
55+
56+
impl From<DeviceType<'_>> for CString {
    // Converts the device name into a NUL-terminated C string for FFI.
    fn from(device: DeviceType) -> Self {
        // NOTE(review): panics if an `Other` device name contains an interior
        // NUL byte — built-in variant names never do.
        CString::new(device.as_ref()).expect("a valid C string")
    }
}
61+
62+
impl<'a> From<&'a str> for DeviceType<'a> {
63+
fn from(s: &'a str) -> Self {
64+
match s {
65+
"CPU" => DeviceType::CPU,
66+
"GPU" => DeviceType::GPU,
67+
"NPU" => DeviceType::NPU,
68+
#[allow(deprecated)]
69+
"GNA" => DeviceType::GNA,
70+
s => DeviceType::Other(Cow::Borrowed(s)),
71+
}
72+
}
73+
}
74+
75+
impl FromStr for DeviceType<'static> {
76+
type Err = Infallible;
77+
78+
fn from_str(s: &str) -> Result<Self, Self::Err> {
79+
Ok(DeviceType::from(s).to_owned())
80+
}
81+
}
82+
83+
impl Display for DeviceType<'_> {
84+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
85+
f.write_str(self.into())
86+
}
87+
}

crates/openvino/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)]
2727

2828
mod core;
29+
mod device_type;
2930
mod dimension;
3031
mod element_type;
3132
mod error;
@@ -42,6 +43,7 @@ mod tensor;
4243
mod util;
4344

4445
pub use crate::core::Core;
46+
pub use device_type::DeviceType;
4547
pub use dimension::Dimension;
4648
pub use element_type::ElementType;
4749
pub use error::{InferenceError, LoadingError, SetupError};

crates/openvino/tests/classify-alexnet.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ mod util;
55

66
use anyhow::Ok;
77
use fixtures::alexnet::Fixture;
8-
use openvino::{prepostprocess, Core, ElementType, Layout, ResizeAlgorithm, Shape, Tensor};
8+
use openvino::{
9+
prepostprocess, Core, DeviceType, ElementType, Layout, ResizeAlgorithm, Shape, Tensor,
10+
};
911
use std::fs;
1012
use util::{Prediction, Predictions};
1113

@@ -45,7 +47,7 @@ fn classify_alexnet() -> anyhow::Result<()> {
4547
let new_model = pre_post_process.build_new_model()?;
4648

4749
// Compile the model and infer the results.
48-
let mut executable_model = core.compile_model(&new_model, "CPU")?;
50+
let mut executable_model = core.compile_model(&new_model, DeviceType::CPU)?;
4951
let mut infer_request = executable_model.create_infer_request()?;
5052
infer_request.set_tensor("data", &tensor)?;
5153
infer_request.infer()?;

crates/openvino/tests/classify-inception.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ mod util;
55

66
use anyhow::Ok;
77
use fixtures::inception::Fixture;
8-
use openvino::{prepostprocess, Core, ElementType, Layout, ResizeAlgorithm, Shape, Tensor};
8+
use openvino::{
9+
prepostprocess, Core, DeviceType, ElementType, Layout, ResizeAlgorithm, Shape, Tensor,
10+
};
911
use std::fs;
1012
use util::{Prediction, Predictions};
1113

@@ -42,7 +44,7 @@ fn classify_inception() -> anyhow::Result<()> {
4244
let new_model = pre_post_process.build_new_model()?;
4345

4446
// Compile the model and infer the results.
45-
let mut executable_model = core.compile_model(&new_model, "CPU")?;
47+
let mut executable_model = core.compile_model(&new_model, DeviceType::CPU)?;
4648
let mut infer_request = executable_model.create_infer_request()?;
4749
infer_request.set_tensor("input", &tensor)?;
4850
infer_request.infer()?;

crates/openvino/tests/classify-mobilenet.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ mod fixtures;
55
mod util;
66

77
use fixtures::mobilenet::Fixture;
8-
use openvino::{prepostprocess, Core, ElementType, Layout, ResizeAlgorithm, Shape, Tensor};
8+
use openvino::{
9+
prepostprocess, Core, DeviceType, ElementType, Layout, ResizeAlgorithm, Shape, Tensor,
10+
};
911
use std::fs;
1012
use util::{Prediction, Predictions};
1113

@@ -45,7 +47,7 @@ fn classify_mobilenet() -> anyhow::Result<()> {
4547
let new_model = pre_post_process.build_new_model()?;
4648

4749
// Compile the model and infer the results.
48-
let mut executable_model = core.compile_model(&new_model, "CPU")?;
50+
let mut executable_model = core.compile_model(&new_model, DeviceType::CPU)?;
4951
let mut infer_request = executable_model.create_infer_request()?;
5052
infer_request.set_tensor("input", &tensor)?;
5153
infer_request.infer()?;

0 commit comments

Comments
 (0)