Skip to content

Commit 423f587

Browse files
authored
Merge pull request #21 from boydjohnson/feature/more-features
Have MemoryDescriptors owned by PrimitiveConfigs/PrimitiveDescriptor
2 parents 4227752 + f023435 commit 423f587

File tree

15 files changed

+461
-270
lines changed

15 files changed

+461
-270
lines changed

benches/binary.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ fn binary_add(b: &mut Bencher) {
3636

3737
let binary_config = ForwardBinaryConfig {
3838
alg_kind: Binary::ADD,
39-
src0_desc: &src0_desc,
40-
src1_desc: &src1_desc,
41-
dst_desc: &dst_desc,
42-
attr: &PrimitiveAttributes::new().unwrap(),
39+
src0_desc: src0_desc.clone_desc().unwrap(),
40+
src1_desc: src1_desc.clone_desc().unwrap(),
41+
dst_desc: dst_desc.clone_desc().unwrap(),
42+
attr: PrimitiveAttributes::new().unwrap(),
4343
};
4444

45-
let primitive =
46-
Primitive::new::<_, PropForwardInference, ForwardBinary<_>>(binary_config, engine.clone());
45+
let primitive = Primitive::<_, PropForwardInference, ForwardBinaryConfig>::new::<
46+
ForwardBinary<_>,
47+
>(binary_config, engine.clone());
4748
assert!(primitive.is_ok());
48-
let primitive = primitive.unwrap();
49+
let mut primitive = primitive.unwrap();
4950

5051
let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();
5152

@@ -87,7 +88,7 @@ fn binary_add(b: &mut Bencher) {
8788

8889
assert!(stream.wait().is_ok());
8990

90-
assert_eq!(result, Ok(()));
91+
assert!(result.is_ok());
9192

9293
assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0]));
9394
});

src/primitive.rs

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ impl PropType<Backward> for PropBackwardData {
108108
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_backward_data;
109109
}
110110

111-
pub struct Primitive {
111+
pub struct Primitive<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P>> {
112112
pub handle: dnnl_primitive_t,
113-
pub desc: PrimitiveDescriptor,
113+
pub desc: Option<PrimitiveDescriptor<'a, D, P, C>>,
114114
pub engine: Arc<Engine>,
115115
}
116116

117-
impl Primitive {
117+
impl<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P>> Primitive<'a, D, P, C> {
118118
/// Creates a new `Primitive`.
119119
///
120120
/// # Example
@@ -141,45 +141,50 @@ impl Primitive {
141141
/// // Define a forward binary config
142142
/// let binary_config = ForwardBinaryConfig {
143143
/// alg_kind: dnnl_alg_kind_t::dnnl_binary_add, // Example: addition operation
144-
/// src0_desc: &src0_desc,
145-
/// src1_desc: &src1_desc,
146-
/// dst_desc: &dst_desc,
147-
/// attr: &PrimitiveAttributes::new().unwrap(),
144+
/// src0_desc: src0_desc,
145+
/// src1_desc: src1_desc,
146+
/// dst_desc: dst_desc,
147+
/// attr: PrimitiveAttributes::new().unwrap(),
148148
/// };
149149
///
150-
/// let primitive =
151-
/// Primitive::new::<_, PropForwardInference, ForwardBinary<_>>(binary_config, engine);
150+
/// let primitive = Primitive::<_, PropForwardInference, ForwardBinaryConfig>::new::<
151+
/// ForwardBinary<_>,
152+
/// >(binary_config, engine);
152153
///
153154
/// assert!(primitive.is_ok());
154155
/// ```
155-
pub fn new<'a, D: Direction, P: PropType<D>, O: Operation<'a, D, P>>(
156+
pub fn new<O: Operation<'a, D, P, OperationConfig = C>>(
156157
config: O::OperationConfig,
157158
engine: Arc<Engine>,
158-
) -> Result<Primitive, DnnlError> {
159+
) -> Result<Primitive<'a, D, P, C>, DnnlError> {
159160
let desc = config.create_primitive_desc(engine.clone())?;
160161
Self::from_descriptor(desc, engine)
161162
}
162163

163164
pub fn from_descriptor(
164-
desc: PrimitiveDescriptor,
165+
desc: PrimitiveDescriptor<'a, D, P, C>,
165166
engine: Arc<Engine>,
166-
) -> Result<Primitive, DnnlError> {
167+
) -> Result<Primitive<'a, D, P, C>, DnnlError> {
167168
let mut handle = std::ptr::null_mut();
168169

169170
let status = unsafe { dnnl_primitive_create(&mut handle, desc.handle) };
170171

171172
if status == dnnl_status_t::dnnl_success {
172-
Ok(Self {
173+
Ok(Primitive::<'a, D, P, C> {
173174
handle,
174-
desc,
175+
desc: Some(desc),
175176
engine,
176177
})
177178
} else {
178179
Err(status.into())
179180
}
180181
}
181182

182-
pub fn execute<T>(&self, stream: &Stream, args: Vec<ExecArg<'_, T>>) -> Result<(), DnnlError> {
183+
pub fn execute<T>(
184+
&mut self,
185+
stream: &Stream,
186+
args: Vec<ExecArg<'_, T>>,
187+
) -> Result<Option<PrimitiveDescriptor<'a, D, P, C>>, DnnlError> {
183188
let c_args: Vec<dnnl_exec_arg_t> = args
184189
.iter()
185190
.map(|arg| dnnl_exec_arg_t {
@@ -198,14 +203,16 @@ impl Primitive {
198203
};
199204

200205
if status == dnnl_status_t::dnnl_success {
201-
Ok(())
206+
Ok(self.desc.take())
202207
} else {
203208
Err(status.into())
204209
}
205210
}
206211
}
207212

208-
impl Drop for Primitive {
213+
impl<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P>> Drop
214+
for Primitive<'a, D, P, C>
215+
{
209216
fn drop(&mut self) {
210217
unsafe {
211218
dnnl_primitive_destroy(self.handle);

src/primitive/config.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ use {
44
std::sync::Arc,
55
};
66

7-
pub trait PrimitiveConfig<'a, D: Direction, P: PropType<D>> {
8-
fn create_primitive_desc(&self, engine: Arc<Engine>) -> Result<PrimitiveDescriptor, DnnlError>;
7+
pub trait PrimitiveConfig<'a, D: Direction, P: PropType<D>>: Sized {
8+
fn create_primitive_desc(
9+
self,
10+
engine: Arc<Engine>,
11+
) -> Result<PrimitiveDescriptor<'a, D, P, Self>, DnnlError>;
912
}

src/primitive/descriptor.rs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,26 @@ use {
22
super::{config::PrimitiveConfig, Direction, Operation, PropType},
33
crate::{engine::Engine, error::DnnlError},
44
onednnl_sys::{dnnl_primitive_desc_destroy, dnnl_primitive_desc_t},
5-
std::sync::Arc,
5+
std::{marker::PhantomData, sync::Arc},
66
};
77

8-
pub struct PrimitiveDescriptor {
8+
pub struct PrimitiveDescriptor<
9+
'a,
10+
D: Direction,
11+
P: PropType<D>,
12+
C: PrimitiveConfig<'a, D, P> + Sized,
13+
> {
914
pub handle: dnnl_primitive_desc_t,
15+
pub config: C,
16+
17+
pub(crate) _marker_a: PhantomData<&'a ()>,
18+
pub(crate) _marker_d: PhantomData<D>,
19+
pub(crate) _marker_p: PhantomData<P>,
1020
}
1121

12-
impl PrimitiveDescriptor {
22+
impl<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P> + Sized>
23+
PrimitiveDescriptor<'a, D, P, C>
24+
{
1325
/// Creates a new `PrimitiveDescriptor`.
1426
///
1527
/// # Example
@@ -37,30 +49,30 @@ impl PrimitiveDescriptor {
3749
/// // Define a forward binary config
3850
/// let binary_config = ForwardBinaryConfig {
3951
/// alg_kind: dnnl_alg_kind_t::dnnl_binary_add, // Example: addition operation
40-
/// src0_desc: &src0_desc,
41-
/// src1_desc: &src1_desc,
42-
/// dst_desc: &dst_desc,
43-
/// attr: &PrimitiveAttributes::new().unwrap(),
52+
/// src0_desc: src0_desc,
53+
/// src1_desc: src1_desc,
54+
/// dst_desc: dst_desc,
55+
/// attr: PrimitiveAttributes::new().unwrap(),
4456
/// };
4557
///
4658
/// // Create a new PrimitiveDescriptor for the forward binary operation
47-
/// let primitive_descriptor = PrimitiveDescriptor::new::<
48-
/// Forward,
49-
/// _,
59+
/// let primitive_descriptor = PrimitiveDescriptor::<_, _, ForwardBinaryConfig>::new::<
5060
/// ForwardBinary<PropForwardInference>,
5161
/// >(binary_config, engine);
5262
///
5363
/// assert!(primitive_descriptor.is_ok());
5464
/// ```
55-
pub fn new<'a, D: Direction, P: PropType<D>, O: Operation<'a, D, P>>(
65+
pub fn new<O: Operation<'a, D, P, OperationConfig = C>>(
5666
config: O::OperationConfig,
5767
engine: Arc<Engine>,
58-
) -> Result<Self, DnnlError> {
68+
) -> Result<PrimitiveDescriptor<'a, D, P, C>, DnnlError> {
5969
config.create_primitive_desc(engine)
6070
}
6171
}
6272

63-
impl Drop for PrimitiveDescriptor {
73+
impl<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P>> Drop
74+
for PrimitiveDescriptor<'a, D, P, C>
75+
{
6476
fn drop(&mut self) {
6577
unsafe { dnnl_primitive_desc_destroy(self.handle) };
6678
}

src/primitives/au_gru.rs

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,36 @@ use {
55
memory::descriptor::MemoryDescriptor,
66
primitive::{
77
attributes::PrimitiveAttributes, config::PrimitiveConfig,
8-
descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType, PropType,
8+
descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType,
9+
PropForwardTraining, PropType,
910
},
1011
},
1112
onednnl_sys::{
1213
dnnl_augru_backward_primitive_desc_create, dnnl_augru_forward_primitive_desc_create,
13-
dnnl_primitive_attr_t, dnnl_rnn_direction_t, dnnl_status_t,
14+
dnnl_rnn_direction_t, dnnl_status_t,
1415
},
15-
std::{ffi::c_uint, sync::Arc},
16+
std::{ffi::c_uint, marker::PhantomData, sync::Arc},
1617
};
1718

18-
pub struct ForwardAuGruConfig<'a> {
19+
pub struct ForwardAuGruConfig {
1920
direction: dnnl_rnn_direction_t::Type,
20-
src_layer_desc: &'a MemoryDescriptor,
21-
src_iter_desc: &'a MemoryDescriptor,
22-
attention_desc: &'a MemoryDescriptor,
23-
weights_layer_desc: &'a MemoryDescriptor,
24-
weights_iter_desc: &'a MemoryDescriptor,
25-
bias_desc: &'a MemoryDescriptor,
26-
dst_layer_desc: &'a MemoryDescriptor,
27-
dst_iter_desc: &'a MemoryDescriptor,
21+
src_layer_desc: MemoryDescriptor,
22+
src_iter_desc: MemoryDescriptor,
23+
attention_desc: MemoryDescriptor,
24+
weights_layer_desc: MemoryDescriptor,
25+
weights_iter_desc: MemoryDescriptor,
26+
bias_desc: MemoryDescriptor,
27+
dst_layer_desc: MemoryDescriptor,
28+
dst_iter_desc: MemoryDescriptor,
2829
flags: c_uint,
29-
attr: &'a PrimitiveAttributes,
30+
attr: PrimitiveAttributes,
3031
}
3132

32-
impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruConfig<'a> {
33-
fn create_primitive_desc(&self, engine: Arc<Engine>) -> Result<PrimitiveDescriptor, DnnlError> {
33+
impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruConfig {
34+
fn create_primitive_desc(
35+
self,
36+
engine: Arc<Engine>,
37+
) -> Result<PrimitiveDescriptor<'a, Forward, P, ForwardAuGruConfig>, DnnlError> {
3438
let mut handle = std::ptr::null_mut();
3539
let status = unsafe {
3640
dnnl_augru_forward_primitive_desc_create(
@@ -52,7 +56,14 @@ impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruC
5256
};
5357

5458
if status == dnnl_status_t::dnnl_success {
55-
Ok(PrimitiveDescriptor { handle })
59+
Ok(PrimitiveDescriptor::<'a, Forward, P, ForwardAuGruConfig> {
60+
handle,
61+
config: self,
62+
63+
_marker_a: PhantomData,
64+
_marker_d: PhantomData,
65+
_marker_p: PhantomData,
66+
})
5667
} else {
5768
Err(status.into())
5869
}
@@ -61,29 +72,32 @@ impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruC
6172

6273
pub struct BackwardAuGruConfig<'a> {
6374
direction: dnnl_rnn_direction_t::Type,
64-
src_layer_desc: &'a MemoryDescriptor,
65-
src_iter_desc: &'a MemoryDescriptor,
66-
attention_desc: &'a MemoryDescriptor,
67-
weights_layer_desc: &'a MemoryDescriptor,
68-
weights_iter_desc: &'a MemoryDescriptor,
69-
bias_desc: &'a MemoryDescriptor,
70-
dst_layer_desc: &'a MemoryDescriptor,
71-
dst_iter_desc: &'a MemoryDescriptor,
72-
diff_src_layer_desc: &'a MemoryDescriptor,
73-
diff_src_iter_desc: &'a MemoryDescriptor,
74-
diff_attention_desc: &'a MemoryDescriptor,
75-
diff_weights_layer_desc: &'a MemoryDescriptor,
76-
diff_weights_iter_desc: &'a MemoryDescriptor,
77-
diff_bias_desc: &'a MemoryDescriptor,
78-
diff_dst_layer_desc: &'a MemoryDescriptor,
79-
diff_dst_iter_desc: &'a MemoryDescriptor,
75+
src_layer_desc: MemoryDescriptor,
76+
src_iter_desc: MemoryDescriptor,
77+
attention_desc: MemoryDescriptor,
78+
weights_layer_desc: MemoryDescriptor,
79+
weights_iter_desc: MemoryDescriptor,
80+
bias_desc: MemoryDescriptor,
81+
dst_layer_desc: MemoryDescriptor,
82+
dst_iter_desc: MemoryDescriptor,
83+
diff_src_layer_desc: MemoryDescriptor,
84+
diff_src_iter_desc: MemoryDescriptor,
85+
diff_attention_desc: MemoryDescriptor,
86+
diff_weights_layer_desc: MemoryDescriptor,
87+
diff_weights_iter_desc: MemoryDescriptor,
88+
diff_bias_desc: MemoryDescriptor,
89+
diff_dst_layer_desc: MemoryDescriptor,
90+
diff_dst_iter_desc: MemoryDescriptor,
8091
flags: c_uint,
81-
hint_fwd_pd: &'a PrimitiveDescriptor,
82-
attr: dnnl_primitive_attr_t,
92+
hint_fwd_pd: &'a PrimitiveDescriptor<'a, Forward, PropForwardTraining, ForwardAuGruConfig>,
93+
attr: PrimitiveAttributes,
8394
}
8495

8596
impl<'a, P: PropType<Backward>> PrimitiveConfig<'a, Backward, P> for BackwardAuGruConfig<'a> {
86-
fn create_primitive_desc(&self, engine: Arc<Engine>) -> Result<PrimitiveDescriptor, DnnlError> {
97+
fn create_primitive_desc(
98+
self,
99+
engine: Arc<Engine>,
100+
) -> Result<PrimitiveDescriptor<'a, Backward, P, BackwardAuGruConfig<'a>>, DnnlError> {
87101
let mut handle = std::ptr::null_mut();
88102
let status = unsafe {
89103
dnnl_augru_backward_primitive_desc_create(
@@ -109,12 +123,20 @@ impl<'a, P: PropType<Backward>> PrimitiveConfig<'a, Backward, P> for BackwardAuG
109123
self.diff_dst_iter_desc.handle,
110124
self.flags,
111125
self.hint_fwd_pd.handle,
112-
self.attr,
126+
self.attr.handle,
113127
)
114128
};
115129

116130
if status == dnnl_status_t::dnnl_success {
117-
Ok(PrimitiveDescriptor { handle })
131+
Ok(
132+
PrimitiveDescriptor::<'a, Backward, P, BackwardAuGruConfig<'a>> {
133+
handle,
134+
config: self,
135+
_marker_a: PhantomData,
136+
_marker_d: PhantomData,
137+
_marker_p: PhantomData,
138+
},
139+
)
118140
} else {
119141
Err(status.into())
120142
}
@@ -125,9 +147,9 @@ pub struct ForwardAuGru<P: PropType<Forward>> {
125147
pub prop_type: P,
126148
}
127149

128-
impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardAuGru<P> {
150+
impl<P: PropType<Forward>> Operation<'_, Forward, P> for ForwardAuGru<P> {
129151
const TYPE: OperationType = OperationType::Augru;
130-
type OperationConfig = ForwardAuGruConfig<'a>;
152+
type OperationConfig = ForwardAuGruConfig;
131153
}
132154

133155
pub struct BackwardAuGru<P: PropType<Backward>> {

0 commit comments

Comments
 (0)