Skip to content

Commit 6617fa6

Browse files
add infer_request functions (#111)
Co-authored-by: Bradley Odell <[email protected]>
1 parent b22b850 commit 6617fa6

File tree

1 file changed

+53
-4
lines changed

1 file changed

+53
-4
lines changed

crates/openvino/src/request.rs

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
use crate::tensor::Tensor;
22
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
33
use openvino_sys::{
4-
ov_infer_request_free, ov_infer_request_get_output_tensor_by_index,
5-
ov_infer_request_get_tensor, ov_infer_request_infer,
6-
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor,
4+
ov_infer_request_cancel, ov_infer_request_free, ov_infer_request_get_input_tensor,
5+
ov_infer_request_get_output_tensor, ov_infer_request_get_output_tensor_by_index,
6+
ov_infer_request_get_tensor, ov_infer_request_infer, ov_infer_request_set_input_tensor,
7+
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_output_tensor,
8+
ov_infer_request_set_output_tensor_by_index, ov_infer_request_set_tensor,
79
ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for,
810
};
911

10-
/// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html).
12+
/// See [`InferRequest`](https://docs.openvino.ai/2024/api/c_cpp_api/group__ov__infer__request__c__api.html).
1113
pub struct InferRequest {
1214
ptr: *mut ov_infer_request_t,
1315
}
@@ -43,6 +45,21 @@ impl InferRequest {
4345
Ok(Tensor::from_ptr(tensor))
4446
}
4547

48+
/// Get an input tensor from the model with only one input tensor.
49+
pub fn get_input_tensor(&self) -> Result<Tensor> {
50+
let mut tensor = std::ptr::null_mut();
51+
try_unsafe!(ov_infer_request_get_input_tensor(
52+
self.ptr,
53+
std::ptr::addr_of_mut!(tensor)
54+
))?;
55+
Ok(Tensor::from_ptr(tensor))
56+
}
57+
58+
/// Set an input tensor for infer models with single input.
59+
pub fn set_input_tensor(&mut self, tensor: &Tensor) -> Result<()> {
60+
try_unsafe!(ov_infer_request_set_input_tensor(self.ptr, tensor.as_ptr()))
61+
}
62+
4663
/// Assing an input [`Tensor`] to the model by its index.
4764
pub fn set_input_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
4865
try_unsafe!(ov_infer_request_set_input_tensor_by_index(
@@ -64,11 +81,43 @@ impl InferRequest {
6481
Ok(Tensor::from_ptr(tensor))
6582
}
6683

84+
/// Get an output tensor from the model with only one output tensor.
85+
pub fn get_output_tensor(&self) -> Result<Tensor> {
86+
let mut tensor = std::ptr::null_mut();
87+
try_unsafe!(ov_infer_request_get_output_tensor(
88+
self.ptr,
89+
std::ptr::addr_of_mut!(tensor)
90+
))?;
91+
Ok(Tensor::from_ptr(tensor))
92+
}
93+
94+
/// Set an output tensor to infer models with single output.
95+
pub fn set_output_tensor(&mut self, tensor: &Tensor) -> Result<()> {
96+
try_unsafe!(ov_infer_request_set_output_tensor(
97+
self.ptr,
98+
tensor.as_ptr()
99+
))
100+
}
101+
102+
/// Set an output tensor to infer by the index of output tensor.
103+
pub fn set_output_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
104+
try_unsafe!(ov_infer_request_set_output_tensor_by_index(
105+
self.ptr,
106+
index,
107+
tensor.as_ptr()
108+
))
109+
}
110+
67111
/// Execute the inference request.
68112
pub fn infer(&mut self) -> Result<()> {
69113
try_unsafe!(ov_infer_request_infer(self.ptr))
70114
}
71115

116+
/// Cancels inference request.
117+
pub fn cancel(&mut self) -> Result<()> {
118+
try_unsafe!(ov_infer_request_cancel(self.ptr))
119+
}
120+
72121
/// Execute the inference request asyncroneously.
73122
pub fn infer_async(&mut self) -> Result<()> {
74123
try_unsafe!(ov_infer_request_start_async(self.ptr))

0 commit comments

Comments
 (0)