diff --git a/.github/workflows/ci-intel.yml b/.github/workflows/ci-intel.yml index 1d7584f..37aabfe 100644 --- a/.github/workflows/ci-intel.yml +++ b/.github/workflows/ci-intel.yml @@ -1,64 +1,64 @@ name: Build onednnl safe bindings for Intel on: - push: - branches: - - main - pull_request: - branches: - - main + push: + branches: + - main + pull_request: + branches: + - main jobs: - build: - runs-on: ubuntu-24.04 - env: - RUSTFLAGS: -D warnings - strategy: - matrix: - rust: - - stable - - nightly + build: + runs-on: ubuntu-24.04 + env: + RUSTFLAGS: -D warnings + strategy: + matrix: + rust: + - stable + - nightly - steps: - - name: Checkout code - uses: actions/checkout@v4 + steps: + - name: Checkout code + uses: actions/checkout@v4 - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@master - with: - toolchain: ${{ matrix.rust }} + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} - - name: Cache oneAPI installation - id: cache-oneapi - uses: actions/cache@v3 - with: - path: /opt/intel/oneapi - key: oneapi-${{ env.CACHE_NUMBER }} - restore-keys: | - oneapi- + - name: Cache oneAPI installation + id: cache-oneapi + uses: actions/cache@v3 + with: + path: /opt/intel/oneapi + key: oneapi-${{ env.CACHE_NUMBER }} + restore-keys: | + oneapi- - - name: Set up Intel oneAPI APT repository - if: steps.cache-oneapi.outputs.cache-hit != 'true' - run: | - wget -qO- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | sudo gpg --dearmor -o /usr/share/keyrings/oneapi-archive-keyring.gpg - echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list - sudo apt update + - name: Set up Intel oneAPI APT repository + if: steps.cache-oneapi.outputs.cache-hit != 'true' + run: | + wget -qO- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | sudo gpg --dearmor -o 
/usr/share/keyrings/oneapi-archive-keyring.gpg + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list + sudo apt update - - name: Install Intel oneAPI DNNL version 2025.0.1 - if: steps.cache-oneapi.outputs.cache-hit != 'true' - run: sudo apt install -y intel-oneapi-dnnl-devel=2025.0.1-* + - name: Install Intel oneAPI OneDNN package 2025.2 + if: steps.cache-oneapi.outputs.cache-hit != 'true' + run: sudo apt install -y intel-oneapi-dnnl-devel-2025.2 - - name: Build - run: | - source /opt/intel/oneapi/setvars.sh - cargo build + - name: Build + run: | + source /opt/intel/oneapi/setvars.sh + cargo build - - name: Run tests - run: | - source /opt/intel/oneapi/setvars.sh - cargo test - - - name: Run docs - run: | - source /opt/intel/oneapi/setvars.sh - cargo doc + - name: Run tests + run: | + source /opt/intel/oneapi/setvars.sh + cargo test + + - name: Run docs + run: | + source /opt/intel/oneapi/setvars.sh + cargo doc diff --git a/src/graph.rs b/src/graph.rs new file mode 100644 index 0000000..758ad1a --- /dev/null +++ b/src/graph.rs @@ -0,0 +1,7 @@ +pub mod compiled_partition; +pub mod graph; +pub mod op; +pub mod ops_builders; +pub mod partition; +pub mod spec; +pub mod tensor; diff --git a/src/graph/compiled_partition.rs b/src/graph/compiled_partition.rs new file mode 100644 index 0000000..51948e2 --- /dev/null +++ b/src/graph/compiled_partition.rs @@ -0,0 +1,71 @@ +use onednnl_sys::{ + dnnl_graph_compiled_partition_create, dnnl_graph_compiled_partition_destroy, + dnnl_graph_compiled_partition_execute, dnnl_graph_compiled_partition_t, dnnl_status_t, +}; + +use crate::{ + error::DnnlError, + graph::{partition::OneDNNGraphPartition, tensor::tensor::Tensor}, + stream::Stream, +}; + +pub struct CompiledPartition { + pub(crate) handle: dnnl_graph_compiled_partition_t, + pub(crate) partition: OneDNNGraphPartition, +} + +impl CompiledPartition { + pub fn 
create(partition: OneDNNGraphPartition) -> Result<Self, DnnlError> { + let mut handle = std::ptr::null_mut(); + let status = unsafe { dnnl_graph_compiled_partition_create(&mut handle, partition.handle) }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + Ok(CompiledPartition { handle, partition }) + } + + pub fn execute( + &self, + stream: &Stream, + inputs: &[Tensor], + outputs: &[&mut Tensor], + ) -> Result<(), DnnlError> { + // Collect the input tensor handles into a vector. This ensures the collection + // of pointers has a stable memory location that lives long enough for the C call. + let mut input_handles: Vec<_> = inputs.iter().map(|t| t.handle as *const _).collect(); + + // Do the same for the output tensor handles. + let mut output_handles: Vec<_> = outputs.iter().map(|t| t.handle as *const _).collect(); + + // The C API expects the number of inputs/outputs as an integer type. + let num_inputs = input_handles.len(); + let num_outputs = output_handles.len(); + + let status = unsafe { + // Now, we pass pointers to our vectors' data, which are guaranteed to be + // valid for the duration of this call. 
+ dnnl_graph_compiled_partition_execute( + self.handle, + stream.handle, + num_inputs, + input_handles.as_mut_ptr(), + num_outputs, + output_handles.as_mut_ptr(), // The C API uses these handles to find the output buffers + ) + }; + + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + + Ok(()) + } +} + +impl Drop for CompiledPartition { + fn drop(&mut self) { + unsafe { + dnnl_graph_compiled_partition_destroy(self.handle); + } + } +} diff --git a/src/graph/graph.rs b/src/graph/graph.rs new file mode 100644 index 0000000..00addf3 --- /dev/null +++ b/src/graph/graph.rs @@ -0,0 +1,141 @@ +use { + super::op::OneDNNGraphOp, + crate::{error::DnnlError, graph::partition::OneDNNGraphPartition}, + onednnl_sys::{ + dnnl_graph_add_op, dnnl_graph_graph_create, dnnl_graph_graph_create_with_fpmath_mode, + dnnl_graph_graph_destroy, dnnl_graph_graph_filter, dnnl_graph_graph_finalize, + dnnl_graph_graph_get_fpmath_mode, dnnl_graph_graph_get_partition_num, + dnnl_graph_graph_get_partitions, dnnl_graph_graph_is_finalized, dnnl_graph_graph_t, + dnnl_status_t, + }, +}; + +pub struct OneDNNGraph { + handle: dnnl_graph_graph_t, + ops: Vec<OneDNNGraphOp>, +} + +impl OneDNNGraph { + pub fn new(engine_type: onednnl_sys::dnnl_engine_kind_t::Type) -> Result<Self, DnnlError> { + let mut handle: dnnl_graph_graph_t = std::ptr::null_mut(); + let status = unsafe { dnnl_graph_graph_create(&mut handle, engine_type) }; + if status == dnnl_status_t::dnnl_success { + Ok(Self { + handle, + ops: Vec::new(), + }) + } else { + Err(status.into()) + } + } + + pub fn new_with_fpmath_mode( + engine_type: onednnl_sys::dnnl_engine_kind_t::Type, + fp_mode: onednnl_sys::dnnl_fpmath_mode_t::Type, + ) -> Result<Self, DnnlError> { + let mut handle: dnnl_graph_graph_t = std::ptr::null_mut(); + let status = + unsafe { dnnl_graph_graph_create_with_fpmath_mode(&mut handle, engine_type, fp_mode) }; + if status == dnnl_status_t::dnnl_success { + Ok(Self { + handle, + ops: Vec::new(), + }) + } else { + Err(status.into()) + } + } + + pub fn 
filter( + &self, + policy: onednnl_sys::dnnl_graph_partition_policy_t::Type, + ) -> Result<(), DnnlError> { + let status = unsafe { dnnl_graph_graph_filter(self.handle, policy) }; + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } + + pub fn finalize(&self) -> Result<(), DnnlError> { + let status = unsafe { dnnl_graph_graph_finalize(self.handle) }; + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } + + pub fn ops(&self) -> &[OneDNNGraphOp] { + &self.ops + } + + pub fn is_finalized(&self) -> Result<bool, DnnlError> { + let mut is_finalized = 0; + let status = unsafe { dnnl_graph_graph_is_finalized(self.handle, &mut is_finalized) }; + if status == dnnl_status_t::dnnl_success { + Ok(is_finalized != 0) + } else { + Err(status.into()) + } + } + + pub fn get_fpmath_mode( + &self, + ) -> Result<(onednnl_sys::dnnl_fpmath_mode_t::Type, i32), DnnlError> { + let mut mode = onednnl_sys::dnnl_fpmath_mode_t::dnnl_fpmath_mode_strict; + let mut apply_to_int = 0; + + let status = + unsafe { dnnl_graph_graph_get_fpmath_mode(self.handle, &mut mode, &mut apply_to_int) }; + if status == dnnl_status_t::dnnl_success { + Ok((mode, apply_to_int)) + } else { + Err(status.into()) + } + } + pub fn get_partition_num(&self) -> Result<usize, DnnlError> { + let mut num = 0; + + let status = unsafe { dnnl_graph_graph_get_partition_num(self.handle, &mut num) }; + if status == dnnl_status_t::dnnl_success { + Ok(num) + } else { + Err(status.into()) + } + } + + pub fn get_partitions(&self) -> Result<Vec<OneDNNGraphPartition>, DnnlError> { + let num = self.get_partition_num()?; + let mut partitions = Vec::with_capacity(num); + + let status = + unsafe { dnnl_graph_graph_get_partitions(self.handle, num, partitions.as_mut_ptr()) }; + if status == dnnl_status_t::dnnl_success { + unsafe { partitions.set_len(partitions.capacity()) }; + Ok(partitions + .into_iter() + .map(|p| OneDNNGraphPartition { handle: p }) + .collect()) + } else { + Err(status.into()) + } + } + + pub fn 
add_op(&mut self, op: OneDNNGraphOp) -> Result<(), DnnlError> { + let status = unsafe { dnnl_graph_add_op(self.handle, op.handle) }; + self.ops.push(op); + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } +} + +impl Drop for OneDNNGraph { + fn drop(&mut self) { + unsafe { dnnl_graph_graph_destroy(self.handle) }; + } +} diff --git a/src/graph/op.rs b/src/graph/op.rs new file mode 100644 index 0000000..057ba9f --- /dev/null +++ b/src/graph/op.rs @@ -0,0 +1,104 @@ +use { + crate::{ + error::DnnlError, + graph::{spec::AttrValue, tensor::logical::LogicalTensor}, + }, + onednnl_sys::{ + dnnl_graph_op_add_input, dnnl_graph_op_add_output, dnnl_graph_op_attr_t, + dnnl_graph_op_create, dnnl_graph_op_destroy, dnnl_graph_op_kind_t, + dnnl_graph_op_set_attr_bool, dnnl_graph_op_set_attr_f32, dnnl_graph_op_set_attr_s64, + dnnl_graph_op_set_attr_str, dnnl_graph_op_t, dnnl_status_t, + }, + std::ffi::CString, +}; + +pub struct OneDNNGraphOp { + pub(crate) handle: dnnl_graph_op_t, +} + +pub type OneDNNGraphOpType = dnnl_graph_op_kind_t::Type; + +impl OneDNNGraphOp { + pub const ABS: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_abs; + pub const ABS_BACKWARD: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_abs_backward; + pub const ADD: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_add; + pub const AVG_POOL: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_avg_pool; + pub const AVG_POOL_BACKWARD: OneDNNGraphOpType = + dnnl_graph_op_kind_t::dnnl_graph_op_avg_pool_backward; + pub const CONVOLUTION: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_convolution; + pub const CLAMP: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_clamp; + pub const CONCAT: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_concat; + pub const MATMUL: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_matmul; + pub const SOFTMAX: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_softmax; + pub const 
STATIC_RESHAPE: OneDNNGraphOpType = + dnnl_graph_op_kind_t::dnnl_graph_op_static_reshape; + pub const REORDER: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_reorder; + + pub fn new( + id: usize, + kind: OneDNNGraphOpType, + verbose_name: impl AsRef<str>, + ) -> Result<Self, DnnlError> { + let c_string = CString::new(verbose_name.as_ref()).unwrap(); + + let mut handle = std::ptr::null_mut(); + let status = unsafe { dnnl_graph_op_create(&mut handle, id, kind, c_string.as_ptr()) }; + if status == dnnl_status_t::dnnl_success { + Ok(Self { handle }) + } else { + Err(status.into()) + } + } + + pub fn add_input(&mut self, tensor: &LogicalTensor) -> Result<(), DnnlError> { + let status = unsafe { dnnl_graph_op_add_input(self.handle, &tensor.handle) }; + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } + + pub fn add_output(&mut self, tensor: &LogicalTensor) -> Result<(), DnnlError> { + let status = unsafe { dnnl_graph_op_add_output(self.handle, &tensor.handle) }; + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } + + pub fn set_attribute( + &mut self, + name: &dnnl_graph_op_attr_t::Type, + value: &AttrValue, + ) -> Result<(), DnnlError> { + let status = match value { + AttrValue::Bool(value) => unsafe { + dnnl_graph_op_set_attr_bool(self.handle, *name, value.as_ptr(), value.len()) + }, + AttrValue::Int(value) => unsafe { + dnnl_graph_op_set_attr_s64(self.handle, *name, value.as_ptr(), value.len()) + }, + AttrValue::Float(value) => unsafe { + dnnl_graph_op_set_attr_f32(self.handle, *name, value.as_ptr(), value.len()) + }, + AttrValue::Str(value) => { + let l = value.len(); + let c_string = CString::new(value.as_str()).unwrap(); + unsafe { dnnl_graph_op_set_attr_str(self.handle, *name, c_string.as_ptr(), l) } + } + }; + if status == dnnl_status_t::dnnl_success { + Ok(()) + } else { + Err(status.into()) + } + } +} + +impl Drop for OneDNNGraphOp { + fn drop(&mut self) { + unsafe { 
dnnl_graph_op_destroy(self.handle) }; + } +} diff --git a/src/graph/ops_builders.rs b/src/graph/ops_builders.rs new file mode 100644 index 0000000..66893d0 --- /dev/null +++ b/src/graph/ops_builders.rs @@ -0,0 +1,120 @@ +use { + super::spec::OpSpec, + crate::{ + error::DnnlError, + graph::{ + op::OneDNNGraphOp, + ops_builders::{ + abs::AbsSpec, abs_backward::AbsBackwardSpec, add::AddSpec, avg_pool::AvgPoolSpec, + avg_pool_backward::AvgPoolBackwardSpec, + batch_norm_inference::BatchNormInferenceSpec, clamp::ClampSpec, + convolution::ConvolutionSpec, elu::EluSpec, end::EndSpec, exp::ExpSpec, + matmul::MatMulSpec, reorder::ReorderSpec, softmax::SoftmaxSpec, + static_reshape::StaticReshapeSpec, + }, + spec::{AttrValue, RequiredAttrs}, + tensor::logical::LogicalTensor, + }, + }, + std::marker::PhantomData, +}; + +pub mod abs; +pub mod abs_backward; +pub mod add; +pub mod avg_pool; +pub mod avg_pool_backward; +pub mod batch_norm_inference; +pub mod clamp; +pub mod concat; +pub mod convolution; +pub mod elu; +pub mod end; +pub mod exp; +pub mod matmul; +pub mod reorder; +pub mod softmax; +pub mod static_reshape; + +pub type OpAttrKind = onednnl_sys::dnnl_graph_op_attr_t::Type; + +pub struct OpBuilder<K> { + id: usize, + inputs: Vec<LogicalTensor>, + outputs: Vec<LogicalTensor>, + attrs: Vec<(OpAttrKind, AttrValue)>, + _marker: PhantomData<K>, +} + +impl<K: OpSpec> OpBuilder<K> { + pub fn new(id: usize) -> Self { + Self { + id, + inputs: vec![], + outputs: vec![], + attrs: vec![], + _marker: PhantomData, + } + } + + pub fn required(mut self, r: impl Into<RequiredAttrs>) -> Self { + if let RequiredAttrs::Some(iter) = r.into() { + for (k, v) in iter { + self.attrs.push((k, v)); + } + } + self + } + + pub fn with_input(mut self, t: LogicalTensor) -> Self { + self.inputs.push(t); + self + } + + pub fn with_output(mut self, t: LogicalTensor) -> Self { + self.outputs.push(t); + self + } + + pub fn with_extra_attr( + mut self, + key: impl Into<OpAttrKind>, + val: impl Into<AttrValue>, + ) -> Self { + self.attrs.push((key.into(), val.into())); + self + } + + pub fn 
build(self, verbose_name: &str) -> Result<OneDNNGraphOp, DnnlError> { + let mut op = OneDNNGraphOp::new(self.id, K::KIND, verbose_name)?; + + for (k, v) in self.attrs { + op.set_attribute(&k, &v)?; + } + + for t in self.inputs { + op.add_input(&t)?; + } + for t in self.outputs { + op.add_output(&t)?; + } + + Ok(op) + } +} + +pub type AbsOpBuilder = OpBuilder<AbsSpec>; +pub type AbsBackwardOpBuilder = OpBuilder<AbsBackwardSpec>; +pub type AddOpBuilder = OpBuilder<AddSpec>; +pub type MatMulOpBuilder = OpBuilder<MatMulSpec>; +pub type ClampOpBuilder = OpBuilder<ClampSpec>; +pub type AvgPoolOpBuilder = OpBuilder<AvgPoolSpec>; +pub type ConvOpBuilder = OpBuilder<ConvolutionSpec>; +pub type SoftmaxOpBuilder = OpBuilder<SoftmaxSpec>; +pub type EndOpBuilder = OpBuilder<EndSpec>; +pub type StaticReshapeOpBuilder = OpBuilder<StaticReshapeSpec>; +pub type ReorderOpBuilder = OpBuilder<ReorderSpec>; +pub type AvgPoolBackwardOpBuilder = OpBuilder<AvgPoolBackwardSpec>; +pub type BatchNormInferenceOpBuilder = OpBuilder<BatchNormInferenceSpec>; +pub type EluOpBuilder = OpBuilder<EluSpec>; +pub type ExpOpBuilder = OpBuilder<ExpSpec>; diff --git a/src/graph/ops_builders/abs.rs b/src/graph/ops_builders/abs.rs new file mode 100644 index 0000000..a5b89f3 --- /dev/null +++ b/src/graph/ops_builders/abs.rs @@ -0,0 +1,10 @@ +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::OpSpec, +}; + +pub struct AbsSpec; + +impl OpSpec for AbsSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::ABS; +} diff --git a/src/graph/ops_builders/abs_backward.rs b/src/graph/ops_builders/abs_backward.rs new file mode 100644 index 0000000..43c4bfc --- /dev/null +++ b/src/graph/ops_builders/abs_backward.rs @@ -0,0 +1,10 @@ +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::OpSpec, +}; + +pub struct AbsBackwardSpec; + +impl OpSpec for AbsBackwardSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::ABS_BACKWARD; +} diff --git a/src/graph/ops_builders/add.rs b/src/graph/ops_builders/add.rs new file mode 100644 index 0000000..17047db --- /dev/null +++ b/src/graph/ops_builders/add.rs @@ -0,0 +1,19 @@ +use { + crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::OpSpec, + }, + 
onednnl_sys::dnnl_graph_op_attr_t, +}; + +pub struct AddSpec; + +impl OpSpec for AddSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::ADD; +} + +impl AddSpec { + /// Possible values of "none" and "numpy" + pub const AUTO_BROADCAST: dnnl_graph_op_attr_t::Type = + dnnl_graph_op_attr_t::dnnl_graph_op_attr_auto_broadcast; +} diff --git a/src/graph/ops_builders/avg_pool.rs b/src/graph/ops_builders/avg_pool.rs new file mode 100644 index 0000000..15cece4 --- /dev/null +++ b/src/graph/ops_builders/avg_pool.rs @@ -0,0 +1,46 @@ +use onednnl_sys::dnnl_graph_op_attr_t::{ + self, dnnl_graph_op_attr_auto_pad, dnnl_graph_op_attr_data_format, + dnnl_graph_op_attr_exclude_pad, dnnl_graph_op_attr_kernel, dnnl_graph_op_attr_pads_begin, + dnnl_graph_op_attr_pads_end, dnnl_graph_op_attr_rounding_type, dnnl_graph_op_attr_strides, +}; + +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct AvgPoolSpec; + +impl OpSpec for AvgPoolSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::AVG_POOL; +} + +impl AvgPoolSpec { + pub const ROUNDING_TYPE: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_rounding_type; + pub const AUTO_PAD: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_auto_pad; + pub const DATA_FORMAT: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_data_format; +} + +#[derive(Debug, Clone)] +pub struct AvgPoolAttrs { + pub strides: Vec<i64>, + pub pads_begin: Vec<i64>, + pub pads_end: Vec<i64>, + pub exclude_pad: bool, + pub kernel: Vec<i64>, +} + +impl From<AvgPoolAttrs> for RequiredAttrs { + fn from(attrs: AvgPoolAttrs) -> Self { + RequiredAttrs::Some(vec![ + (dnnl_graph_op_attr_strides, attrs.strides.into()), + (dnnl_graph_op_attr_pads_begin, attrs.pads_begin.into()), + (dnnl_graph_op_attr_pads_end, attrs.pads_end.into()), + ( + dnnl_graph_op_attr_exclude_pad, + vec![attrs.exclude_pad as u8].into(), + ), + (dnnl_graph_op_attr_kernel, attrs.kernel.into()), + ]) + } +} diff --git a/src/graph/ops_builders/avg_pool_backward.rs 
b/src/graph/ops_builders/avg_pool_backward.rs new file mode 100644 index 0000000..b7d92f7 --- /dev/null +++ b/src/graph/ops_builders/avg_pool_backward.rs @@ -0,0 +1,25 @@ +use onednnl_sys::{dnnl_graph_op_attr_t::dnnl_graph_op_attr_data_format, dnnl_graph_op_kind_t}; + +use crate::graph::{ + op::OneDNNGraphOpType, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct AvgPoolBackwardSpec; + +impl OpSpec for AvgPoolBackwardSpec { + const KIND: OneDNNGraphOpType = dnnl_graph_op_kind_t::dnnl_graph_op_avg_pool_backward; +} + +pub struct AvgPoolBackwardAttrs { + pub data_format: String, +} + +impl From<AvgPoolBackwardAttrs> for RequiredAttrs { + fn from(attrs: AvgPoolBackwardAttrs) -> Self { + RequiredAttrs::Some(vec![( + dnnl_graph_op_attr_data_format, + attrs.data_format.into(), + )]) + } +} diff --git a/src/graph/ops_builders/batch_norm_inference.rs b/src/graph/ops_builders/batch_norm_inference.rs new file mode 100644 index 0000000..3f875ec --- /dev/null +++ b/src/graph/ops_builders/batch_norm_inference.rs @@ -0,0 +1,23 @@ +use onednnl_sys::dnnl_graph_op_attr_t::dnnl_graph_op_attr_epsilon; + +use crate::graph::spec::RequiredAttrs; + +pub struct BatchNormInferenceSpec; + +pub struct BatchNormInferenceAttrs { + pub epsilon: f32, +} + +impl BatchNormInferenceSpec { + pub const DATA_FORMAT: onednnl_sys::dnnl_graph_op_attr_t::Type = + onednnl_sys::dnnl_graph_op_attr_t::dnnl_graph_op_attr_data_format; +} + +impl From<BatchNormInferenceAttrs> for RequiredAttrs { + fn from(attrs: BatchNormInferenceAttrs) -> Self { + RequiredAttrs::Some(vec![( + dnnl_graph_op_attr_epsilon, + vec![attrs.epsilon].into(), + )]) + } +} diff --git a/src/graph/ops_builders/clamp.rs b/src/graph/ops_builders/clamp.rs new file mode 100644 index 0000000..39101fb --- /dev/null +++ b/src/graph/ops_builders/clamp.rs @@ -0,0 +1,27 @@ +use onednnl_sys::dnnl_graph_op_attr_t::{dnnl_graph_op_attr_max, dnnl_graph_op_attr_min}; + +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct ClampSpec; + +impl 
OpSpec for ClampSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::CLAMP; +} + +#[derive(Debug, Clone, Copy)] +pub struct ClampAttrs { + pub min: f32, + pub max: f32, +} + +impl From<ClampAttrs> for RequiredAttrs { + fn from(attrs: ClampAttrs) -> Self { + RequiredAttrs::Some(vec![ + (dnnl_graph_op_attr_min, vec![attrs.min].into()), + (dnnl_graph_op_attr_max, vec![attrs.max].into()), + ]) + } +} diff --git a/src/graph/ops_builders/concat.rs b/src/graph/ops_builders/concat.rs new file mode 100644 index 0000000..2f6621a --- /dev/null +++ b/src/graph/ops_builders/concat.rs @@ -0,0 +1,23 @@ +use onednnl_sys::dnnl_graph_op_attr_t::dnnl_graph_op_attr_axis; + +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct ConcatSpec; + +impl OpSpec for ConcatSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::CONCAT; +} + +#[derive(Debug, Clone, Copy)] +pub struct ConcatAttrs { + pub axis: i64, +} + +impl From<ConcatAttrs> for RequiredAttrs { + fn from(attrs: ConcatAttrs) -> Self { + RequiredAttrs::Some(vec![(dnnl_graph_op_attr_axis, vec![attrs.axis].into())]) + } +} diff --git a/src/graph/ops_builders/convolution.rs b/src/graph/ops_builders/convolution.rs new file mode 100644 index 0000000..cca022e --- /dev/null +++ b/src/graph/ops_builders/convolution.rs @@ -0,0 +1,45 @@ +use { + crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::{OpSpec, RequiredAttrs}, + }, + onednnl_sys::dnnl_graph_op_attr_t::{ + self, dnnl_graph_op_attr_dilations, dnnl_graph_op_attr_pads_begin, + dnnl_graph_op_attr_pads_end, dnnl_graph_op_attr_strides, + }, +}; + +pub struct ConvolutionSpec; + +impl ConvolutionSpec { + pub const GROUPS: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_t::dnnl_graph_op_attr_groups; + pub const AUTO_PAD: dnnl_graph_op_attr_t::Type = + dnnl_graph_op_attr_t::dnnl_graph_op_attr_auto_pad; + pub const DATA_FORMAT: dnnl_graph_op_attr_t::Type = + dnnl_graph_op_attr_t::dnnl_graph_op_attr_data_format; + pub const WEIGHTS_FORMAT: 
dnnl_graph_op_attr_t::Type = + dnnl_graph_op_attr_t::dnnl_graph_op_attr_weights_format; +} + +impl OpSpec for ConvolutionSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::CONVOLUTION; +} + +#[derive(Debug, Clone)] +pub struct ConvolutionAttrs { + pub strides: Vec<i64>, + pub pads_begin: Vec<i64>, + pub pads_end: Vec<i64>, + pub dilations: Vec<i64>, +} + +impl From<ConvolutionAttrs> for RequiredAttrs { + fn from(attrs: ConvolutionAttrs) -> Self { + RequiredAttrs::Some(vec![ + (dnnl_graph_op_attr_strides, attrs.strides.into()), + (dnnl_graph_op_attr_pads_begin, attrs.pads_begin.into()), + (dnnl_graph_op_attr_pads_end, attrs.pads_end.into()), + (dnnl_graph_op_attr_dilations, attrs.dilations.into()), + ]) + } +} diff --git a/src/graph/ops_builders/elu.rs b/src/graph/ops_builders/elu.rs new file mode 100644 index 0000000..7b68e75 --- /dev/null +++ b/src/graph/ops_builders/elu.rs @@ -0,0 +1,25 @@ +use onednnl_sys::{ + dnnl_graph_op_attr_t::dnnl_graph_op_attr_epsilon, dnnl_graph_op_kind_t::dnnl_graph_op_elu, +}; + +use crate::graph::{ + op::OneDNNGraphOpType, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct EluSpec; + +#[derive(Debug, Clone, Copy)] +pub struct EluAttrs { + pub alpha: f32, +} + +impl OpSpec for EluSpec { + const KIND: OneDNNGraphOpType = dnnl_graph_op_elu; +} + +impl From<EluAttrs> for RequiredAttrs { + fn from(attrs: EluAttrs) -> Self { + RequiredAttrs::Some(vec![(dnnl_graph_op_attr_epsilon, vec![attrs.alpha].into())]) + } +} diff --git a/src/graph/ops_builders/end.rs b/src/graph/ops_builders/end.rs new file mode 100644 index 0000000..d112751 --- /dev/null +++ b/src/graph/ops_builders/end.rs @@ -0,0 +1,8 @@ +use crate::graph::spec::OpSpec; + +pub struct EndSpec; + +impl OpSpec for EndSpec { + const KIND: onednnl_sys::dnnl_graph_op_kind_t::Type = + onednnl_sys::dnnl_graph_op_kind_t::dnnl_graph_op_end; +} diff --git a/src/graph/ops_builders/exp.rs b/src/graph/ops_builders/exp.rs new file mode 100644 index 0000000..0a43b04 --- /dev/null +++ b/src/graph/ops_builders/exp.rs @@ -0,0 +1,10 @@ +use { + 
crate::graph::{op::OneDNNGraphOpType, spec::OpSpec}, + onednnl_sys::dnnl_graph_op_kind_t::dnnl_graph_op_exp, +}; + +pub struct ExpSpec; + +impl OpSpec for ExpSpec { + const KIND: OneDNNGraphOpType = dnnl_graph_op_exp; +} diff --git a/src/graph/ops_builders/matmul.rs b/src/graph/ops_builders/matmul.rs new file mode 100644 index 0000000..9a8d2c0 --- /dev/null +++ b/src/graph/ops_builders/matmul.rs @@ -0,0 +1,41 @@ +use onednnl_sys::dnnl_graph_op_attr_t::{ + dnnl_graph_op_attr_transpose_a, dnnl_graph_op_attr_transpose_b, +}; + +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + ops_builders::OpAttrKind, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct MatMulSpec; + +impl OpSpec for MatMulSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::MATMUL; +} + +impl MatMulSpec { + pub const TRANSPOSE_A: OpAttrKind = dnnl_graph_op_attr_transpose_a; + pub const TRANSPOSE_B: OpAttrKind = dnnl_graph_op_attr_transpose_b; +} + +#[derive(Debug, Clone, Copy)] +pub struct MatMulAttrs { + pub transpose_a: bool, + pub transpose_b: bool, +} + +impl From<MatMulAttrs> for RequiredAttrs { + fn from(attrs: MatMulAttrs) -> Self { + RequiredAttrs::Some(vec![ + ( + dnnl_graph_op_attr_transpose_a, + vec![attrs.transpose_a as u8].into(), + ), + ( + dnnl_graph_op_attr_transpose_b, + vec![attrs.transpose_b as u8].into(), + ), + ]) + } +} diff --git a/src/graph/ops_builders/reorder.rs b/src/graph/ops_builders/reorder.rs new file mode 100644 index 0000000..541d2af --- /dev/null +++ b/src/graph/ops_builders/reorder.rs @@ -0,0 +1,10 @@ +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::OpSpec, +}; + +pub struct ReorderSpec; + +impl OpSpec for ReorderSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::REORDER; +} diff --git a/src/graph/ops_builders/softmax.rs b/src/graph/ops_builders/softmax.rs new file mode 100644 index 0000000..a2ea0a5 --- /dev/null +++ b/src/graph/ops_builders/softmax.rs @@ -0,0 +1,16 @@ +use onednnl_sys::dnnl_graph_op_attr_t; + +use crate::graph::{ + 
op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::OpSpec, +}; + +pub struct SoftmaxSpec; + +impl OpSpec for SoftmaxSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::SOFTMAX; +} + +impl SoftmaxSpec { + pub const AXIS: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_t::dnnl_graph_op_attr_axis; +} diff --git a/src/graph/ops_builders/static_reshape.rs b/src/graph/ops_builders/static_reshape.rs new file mode 100644 index 0000000..1908a79 --- /dev/null +++ b/src/graph/ops_builders/static_reshape.rs @@ -0,0 +1,31 @@ +use onednnl_sys::dnnl_graph_op_attr_t::{ + dnnl_graph_op_attr_shape, dnnl_graph_op_attr_special_zero, +}; + +use crate::graph::{ + op::{OneDNNGraphOp, OneDNNGraphOpType}, + spec::{OpSpec, RequiredAttrs}, +}; + +pub struct StaticReshapeSpec; + +impl OpSpec for StaticReshapeSpec { + const KIND: OneDNNGraphOpType = OneDNNGraphOp::STATIC_RESHAPE; +} + +pub struct StaticReshapeAttrs { + pub shape: Vec<i64>, + pub special_zero: u8, +} + +impl From<StaticReshapeAttrs> for RequiredAttrs { + fn from(attrs: StaticReshapeAttrs) -> Self { + RequiredAttrs::Some(vec![ + (dnnl_graph_op_attr_shape, attrs.shape.into()), + ( + dnnl_graph_op_attr_special_zero, + vec![attrs.special_zero].into(), + ), + ]) + } +} diff --git a/src/graph/partition.rs b/src/graph/partition.rs new file mode 100644 index 0000000..89cd5c9 --- /dev/null +++ b/src/graph/partition.rs @@ -0,0 +1,162 @@ +use { + crate::{ + engine::Engine, + error::DnnlError, + graph::{ + compiled_partition::CompiledPartition, op::OneDNNGraphOp, + tensor::logical::LogicalTensor, + }, + }, + onednnl_sys::{ + dnnl_engine_kind_t, dnnl_graph_logical_tensor_t, dnnl_graph_partition_compile, + dnnl_graph_partition_create_with_op, dnnl_graph_partition_destroy, + dnnl_graph_partition_get_id, dnnl_graph_partition_get_input_ports, + dnnl_graph_partition_get_input_ports_num, dnnl_graph_partition_get_output_ports, + dnnl_graph_partition_get_output_ports_num, dnnl_graph_partition_is_supported, + dnnl_graph_partition_t, dnnl_status_t, + }, +}; + +pub struct 
OneDNNGraphPartition { + pub(crate) handle: dnnl_graph_partition_t, +} + +impl OneDNNGraphPartition { + pub fn create(engine: dnnl_engine_kind_t::Type, op: &OneDNNGraphOp) -> Result<Self, DnnlError> { + let mut handle = std::ptr::null_mut(); + + let status = unsafe { + dnnl_graph_partition_create_with_op(&mut handle, op.handle as *const _, engine) + }; + if status != dnnl_status_t::dnnl_success { + Err(status.into()) + } else { + Ok(Self { handle }) + } + } + + pub fn id(&self) -> Result<usize, DnnlError> { + let mut output = 0; + let status = unsafe { dnnl_graph_partition_get_id(self.handle, &mut output) }; + if status != dnnl_status_t::dnnl_success { + Err(status.into()) + } else { + Ok(output) + } + } + + pub fn get_output_ports_num(&self) -> Result<usize, DnnlError> { + let mut num = 0; + + let status = unsafe { dnnl_graph_partition_get_output_ports_num(self.handle, &mut num) }; + if status != dnnl_status_t::dnnl_success { + Err(status.into()) + } else { + Ok(num) + } + } + + pub fn get_input_ports_num(&self) -> Result<usize, DnnlError> { + let mut num = 0; + + let status = unsafe { dnnl_graph_partition_get_input_ports_num(self.handle, &mut num) }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + Ok(num) + } + + pub fn get_input_ports(&self) -> Result<Vec<LogicalTensor>, DnnlError> { + let num = self.get_input_ports_num()? 
as usize; + + // 1) Reserve space for `num` dnnl_graph_logical_tensor_t values + let mut raw_ports = Vec::<dnnl_graph_logical_tensor_t>::with_capacity(num); + + // 2) Call the C API to fill them in + let status = unsafe { + dnnl_graph_partition_get_input_ports(self.handle, num, raw_ports.as_mut_ptr()) + }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + + // 3) *now* those slots are initialized—tell Rust how many valid entries there are + unsafe { raw_ports.set_len(num) }; + + // 4) Wrap each one in your safe type + Ok(raw_ports + .into_iter() + .map(|handle| LogicalTensor { handle }) + .collect()) + } + + pub fn get_output_ports(&self) -> Result<Vec<LogicalTensor>, DnnlError> { + let num = self.get_output_ports_num()? as usize; + let mut raw_ports = Vec::<dnnl_graph_logical_tensor_t>::with_capacity(num); + + let status = unsafe { + dnnl_graph_partition_get_output_ports(self.handle, num, raw_ports.as_mut_ptr()) + }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + + unsafe { raw_ports.set_len(num) }; + Ok(raw_ports + .into_iter() + .map(|handle| LogicalTensor { handle }) + .collect()) + } + + pub fn is_supported(&self) -> Result<bool, DnnlError> { + let mut supported = 0; + let status = unsafe { dnnl_graph_partition_is_supported(self.handle, &mut supported) }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + Ok(supported != 0) + } + + pub fn compile(self, engine: &Engine) -> Result<CompiledPartition, DnnlError> { + let in_num = self.get_input_ports_num()?; + let out_num = self.get_output_ports_num()?; + let input_logical_tensors = self.get_input_ports()?; + + let output_logical_tensors = self.get_output_ports()?; + + let mut inputs = input_logical_tensors + .iter() + .map(|e| &e.handle as *const dnnl_graph_logical_tensor_t) + .collect::<Vec<_>>(); + let mut outputs = output_logical_tensors + .iter() + .map(|e| &e.handle as *const dnnl_graph_logical_tensor_t) + .collect::<Vec<_>>(); + + let cp = CompiledPartition::create(self)?; + + let status = unsafe { + dnnl_graph_partition_compile( + cp.partition.handle, 
+ cp.handle, + in_num, + inputs.as_mut_ptr(), + out_num, + outputs.as_mut_ptr(), + engine.handle, + ) + }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + Ok(cp) + } +} + +impl Drop for OneDNNGraphPartition { + fn drop(&mut self) { + unsafe { + dnnl_graph_partition_destroy(self.handle); + } + } +} diff --git a/src/graph/spec.rs b/src/graph/spec.rs new file mode 100644 index 0000000..59ae8db --- /dev/null +++ b/src/graph/spec.rs @@ -0,0 +1,42 @@ +use onednnl_sys::{dnnl_graph_op_attr_t, dnnl_graph_op_kind_t}; + +pub trait OpSpec { + const KIND: dnnl_graph_op_kind_t::Type; +} + +pub enum RequiredAttrs { + None, + Some(Vec<(dnnl_graph_op_attr_t::Type, AttrValue)>), +} + +#[derive(Debug, Clone)] +pub enum AttrValue { + Bool(Vec), + Int(Vec), + Float(Vec), + Str(String), +} + +impl From> for AttrValue { + fn from(value: Vec) -> Self { + AttrValue::Bool(value) + } +} + +impl From> for AttrValue { + fn from(value: Vec) -> Self { + AttrValue::Int(value) + } +} + +impl From> for AttrValue { + fn from(value: Vec) -> Self { + AttrValue::Float(value) + } +} + +impl From for AttrValue { + fn from(value: String) -> Self { + AttrValue::Str(value) + } +} diff --git a/src/graph/tensor.rs b/src/graph/tensor.rs new file mode 100644 index 0000000..39ee1dd --- /dev/null +++ b/src/graph/tensor.rs @@ -0,0 +1,2 @@ +pub mod logical; +pub mod tensor; diff --git a/src/graph/tensor/logical.rs b/src/graph/tensor/logical.rs new file mode 100644 index 0000000..6ac50c9 --- /dev/null +++ b/src/graph/tensor/logical.rs @@ -0,0 +1,80 @@ +use { + crate::error::DnnlError, + onednnl_sys::{ + dnnl_data_type_t, dnnl_dims_t, dnnl_graph_layout_type_t, dnnl_graph_logical_tensor_init, + dnnl_graph_logical_tensor_init_with_dims, dnnl_graph_logical_tensor_t, + dnnl_graph_tensor_property_t, dnnl_status_t, + }, + std::mem::MaybeUninit, +}; + +pub struct LogicalTensor { + pub(crate) handle: onednnl_sys::dnnl_graph_logical_tensor_t, +} + +impl LogicalTensor { + pub fn create( + tid: 
usize, + dtype: dnnl_data_type_t::Type, + ndims: i32, + layout: dnnl_graph_layout_type_t::Type, + property: dnnl_graph_tensor_property_t::Type, + ) -> Result { + // allocate uninitialized LT + let mut lt = MaybeUninit::::uninit(); + let status = unsafe { + dnnl_graph_logical_tensor_init(lt.as_mut_ptr(), tid, dtype, ndims, layout, property) + }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + // assume_init is now safe + Ok(LogicalTensor { + handle: unsafe { lt.assume_init() }, + }) + } + + pub fn id(&self) -> usize { + self.handle.id + } + + pub fn new_with_dims( + tid: usize, + dtype: dnnl_data_type_t::Type, + dims: &[i64], + layout: dnnl_graph_layout_type_t::Type, + property: dnnl_graph_tensor_property_t::Type, + ) -> Result { + let ndims = dims.len() as i32; + let mut c_dims: dnnl_dims_t = [0; 12]; + for (i, &dim) in dims.iter().enumerate() { + c_dims[i] = dim; + } + + let mut lt = MaybeUninit::::uninit(); + let status = unsafe { + dnnl_graph_logical_tensor_init_with_dims( + lt.as_mut_ptr(), + tid, + dtype, + ndims, + c_dims.as_ptr(), + layout, + property, + ) + }; + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + Ok(LogicalTensor { + handle: unsafe { lt.assume_init() }, + }) + } + + pub fn get_dims(&self) -> Vec { + let ndims = self.handle.ndims; + let dims = self.handle.dims; + + dims.iter().take(ndims as usize).map(|&dim| dim).collect() + } +} diff --git a/src/graph/tensor/tensor.rs b/src/graph/tensor/tensor.rs new file mode 100644 index 0000000..14a5f0b --- /dev/null +++ b/src/graph/tensor/tensor.rs @@ -0,0 +1,87 @@ +use { + crate::{ + engine, error::DnnlError, graph::tensor::logical::LogicalTensor, + memory::DNNL_MEMORY_ALLOCATE, + }, + onednnl_sys::{ + dnnl_graph_logical_tensor_t, dnnl_graph_tensor_create, dnnl_graph_tensor_destroy, + dnnl_graph_tensor_get_data_handle, dnnl_graph_tensor_t, dnnl_status_t, + }, + std::os::raw::c_void, +}; + +#[derive(Debug, Clone)] +pub struct Tensor { + 
pub(crate) handle: dnnl_graph_tensor_t, +} + +impl Tensor { + pub fn new( + logical_tensor: &LogicalTensor, + engine: &engine::Engine, + data: &[f32], + ) -> Result { + let mut handle = std::ptr::null_mut(); + let status = unsafe { + dnnl_graph_tensor_create( + &mut handle, + &logical_tensor.handle as *const dnnl_graph_logical_tensor_t, + engine.handle, + data.as_ptr() as *mut c_void, + ) + }; + if status != dnnl_status_t::dnnl_success { + Err(status.into()) + } else { + Ok(Self { handle }) + } + } + + pub fn new_library_allocated( + logical_tensor: &LogicalTensor, + engine: &engine::Engine, + ) -> Result { + let mut handle = std::ptr::null_mut(); + let status = unsafe { + dnnl_graph_tensor_create( + &mut handle, + &logical_tensor.handle as *const dnnl_graph_logical_tensor_t, + engine.handle, + DNNL_MEMORY_ALLOCATE, + ) + }; + if status != dnnl_status_t::dnnl_success { + Err(status.into()) + } else { + Ok(Self { handle }) + } + } + + pub fn get_data_handle(&self, size: usize) -> Result, DnnlError> { + let mut data_handle: *mut c_void = std::ptr::null_mut(); + + let status = unsafe { dnnl_graph_tensor_get_data_handle(self.handle, &mut data_handle) }; + + if status != dnnl_status_t::dnnl_success { + return Err(status.into()); + } + + // 4. Allocate a Rust Vec to copy the data into. 
+ let mut rust_buffer = vec![0.0f32; size]; + + unsafe { + std::ptr::copy_nonoverlapping( + data_handle as *const f32, + rust_buffer.as_mut_ptr(), + size, + ); + } + Ok(rust_buffer) + } +} + +impl Drop for Tensor { + fn drop(&mut self) { + unsafe { dnnl_graph_tensor_destroy(self.handle) }; + } +} diff --git a/src/lib.rs b/src/lib.rs index 1bc7cef..c136d77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod engine; pub mod error; +pub mod graph; pub mod memory; pub mod primitive; pub mod primitives; diff --git a/src/memory.rs b/src/memory.rs index cf48d12..0213e3c 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -19,10 +19,10 @@ pub fn data_type_size(ty: dnnl_data_type_t::Type) -> usize { } /// Memory without an underlying buffer -const DNNL_MEMORY_NONE: *mut c_void = std::ptr::null_mut(); +pub const DNNL_MEMORY_NONE: *mut c_void = std::ptr::null_mut(); /// Memory with library allocated buffer -const DNNL_MEMORY_ALLOCATE: *mut c_void = (usize::MAX) as *mut c_void; +pub const DNNL_MEMORY_ALLOCATE: *mut c_void = (usize::MAX) as *mut c_void; pub mod buffer; pub mod descriptor;