Skip to content

Commit 58528ef

Browse files
authored
Merge pull request #23 from boydjohnson/small-fixes
Add query_logical_tensor to CompiledPartition
2 parents 2aa3116 + f84cbf4 commit 58528ef

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

src/graph/compiled_partition.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
use onednnl_sys::{
22
dnnl_graph_compiled_partition_create, dnnl_graph_compiled_partition_destroy,
3-
dnnl_graph_compiled_partition_execute, dnnl_graph_compiled_partition_t, dnnl_status_t,
3+
dnnl_graph_compiled_partition_execute, dnnl_graph_compiled_partition_query_logical_tensor,
4+
dnnl_graph_compiled_partition_t, dnnl_status_t,
45
};
56

67
use crate::{
78
error::DnnlError,
8-
graph::{partition::OneDNNGraphPartition, tensor::tensor::Tensor},
9+
graph::{
10+
partition::OneDNNGraphPartition,
11+
tensor::{logical::LogicalTensor, tensor::Tensor},
12+
},
913
stream::Stream,
1014
};
1115

@@ -60,6 +64,28 @@ impl CompiledPartition {
6064

6165
Ok(())
6266
}
67+
68+
pub fn query_logical_tensor(&self, index: usize) -> Result<LogicalTensor, DnnlError> {
69+
let mut logical_tensor = std::mem::MaybeUninit::uninit();
70+
let status = unsafe {
71+
dnnl_graph_compiled_partition_query_logical_tensor(
72+
self.handle,
73+
index,
74+
logical_tensor.as_mut_ptr(),
75+
)
76+
};
77+
78+
if status != dnnl_status_t::dnnl_success {
79+
return Err(status.into());
80+
}
81+
82+
let lt = unsafe {
83+
LogicalTensor {
84+
handle: logical_tensor.assume_init(),
85+
}
86+
};
87+
Ok(lt)
88+
}
6389
}
6490

6591
impl Drop for CompiledPartition {

0 commit comments

Comments
 (0)