55 memory:: descriptor:: MemoryDescriptor ,
66 primitive:: {
77 attributes:: PrimitiveAttributes , config:: PrimitiveConfig ,
8- descriptor:: PrimitiveDescriptor , Backward , Forward , Operation , OperationType , PropType ,
8+ descriptor:: PrimitiveDescriptor , Backward , Forward , Operation , OperationType ,
9+ PropForwardTraining , PropType ,
910 } ,
1011 } ,
1112 onednnl_sys:: {
1213 dnnl_augru_backward_primitive_desc_create, dnnl_augru_forward_primitive_desc_create,
13- dnnl_primitive_attr_t , dnnl_rnn_direction_t, dnnl_status_t,
14+ dnnl_rnn_direction_t, dnnl_status_t,
1415 } ,
15- std:: { ffi:: c_uint, sync:: Arc } ,
16+ std:: { ffi:: c_uint, marker :: PhantomData , sync:: Arc } ,
1617} ;
1718
18- pub struct ForwardAuGruConfig < ' a > {
19+ pub struct ForwardAuGruConfig {
1920 direction : dnnl_rnn_direction_t:: Type ,
20- src_layer_desc : & ' a MemoryDescriptor ,
21- src_iter_desc : & ' a MemoryDescriptor ,
22- attention_desc : & ' a MemoryDescriptor ,
23- weights_layer_desc : & ' a MemoryDescriptor ,
24- weights_iter_desc : & ' a MemoryDescriptor ,
25- bias_desc : & ' a MemoryDescriptor ,
26- dst_layer_desc : & ' a MemoryDescriptor ,
27- dst_iter_desc : & ' a MemoryDescriptor ,
21+ src_layer_desc : MemoryDescriptor ,
22+ src_iter_desc : MemoryDescriptor ,
23+ attention_desc : MemoryDescriptor ,
24+ weights_layer_desc : MemoryDescriptor ,
25+ weights_iter_desc : MemoryDescriptor ,
26+ bias_desc : MemoryDescriptor ,
27+ dst_layer_desc : MemoryDescriptor ,
28+ dst_iter_desc : MemoryDescriptor ,
2829 flags : c_uint ,
29- attr : & ' a PrimitiveAttributes ,
30+ attr : PrimitiveAttributes ,
3031}
3132
32- impl < ' a , P : PropType < Forward > > PrimitiveConfig < ' a , Forward , P > for ForwardAuGruConfig < ' a > {
33- fn create_primitive_desc ( & self , engine : Arc < Engine > ) -> Result < PrimitiveDescriptor , DnnlError > {
33+ impl < ' a , P : PropType < Forward > > PrimitiveConfig < ' a , Forward , P > for ForwardAuGruConfig {
34+ fn create_primitive_desc (
35+ self ,
36+ engine : Arc < Engine > ,
37+ ) -> Result < PrimitiveDescriptor < ' a , Forward , P , ForwardAuGruConfig > , DnnlError > {
3438 let mut handle = std:: ptr:: null_mut ( ) ;
3539 let status = unsafe {
3640 dnnl_augru_forward_primitive_desc_create (
@@ -52,7 +56,14 @@ impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruC
5256 } ;
5357
5458 if status == dnnl_status_t:: dnnl_success {
55- Ok ( PrimitiveDescriptor { handle } )
59+ Ok ( PrimitiveDescriptor :: < ' a , Forward , P , ForwardAuGruConfig > {
60+ handle,
61+ config : self ,
62+
63+ _marker_a : PhantomData ,
64+ _marker_d : PhantomData ,
65+ _marker_p : PhantomData ,
66+ } )
5667 } else {
5768 Err ( status. into ( ) )
5869 }
@@ -61,29 +72,32 @@ impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruC
6172
6273pub struct BackwardAuGruConfig < ' a > {
6374 direction : dnnl_rnn_direction_t:: Type ,
64- src_layer_desc : & ' a MemoryDescriptor ,
65- src_iter_desc : & ' a MemoryDescriptor ,
66- attention_desc : & ' a MemoryDescriptor ,
67- weights_layer_desc : & ' a MemoryDescriptor ,
68- weights_iter_desc : & ' a MemoryDescriptor ,
69- bias_desc : & ' a MemoryDescriptor ,
70- dst_layer_desc : & ' a MemoryDescriptor ,
71- dst_iter_desc : & ' a MemoryDescriptor ,
72- diff_src_layer_desc : & ' a MemoryDescriptor ,
73- diff_src_iter_desc : & ' a MemoryDescriptor ,
74- diff_attention_desc : & ' a MemoryDescriptor ,
75- diff_weights_layer_desc : & ' a MemoryDescriptor ,
76- diff_weights_iter_desc : & ' a MemoryDescriptor ,
77- diff_bias_desc : & ' a MemoryDescriptor ,
78- diff_dst_layer_desc : & ' a MemoryDescriptor ,
79- diff_dst_iter_desc : & ' a MemoryDescriptor ,
75+ src_layer_desc : MemoryDescriptor ,
76+ src_iter_desc : MemoryDescriptor ,
77+ attention_desc : MemoryDescriptor ,
78+ weights_layer_desc : MemoryDescriptor ,
79+ weights_iter_desc : MemoryDescriptor ,
80+ bias_desc : MemoryDescriptor ,
81+ dst_layer_desc : MemoryDescriptor ,
82+ dst_iter_desc : MemoryDescriptor ,
83+ diff_src_layer_desc : MemoryDescriptor ,
84+ diff_src_iter_desc : MemoryDescriptor ,
85+ diff_attention_desc : MemoryDescriptor ,
86+ diff_weights_layer_desc : MemoryDescriptor ,
87+ diff_weights_iter_desc : MemoryDescriptor ,
88+ diff_bias_desc : MemoryDescriptor ,
89+ diff_dst_layer_desc : MemoryDescriptor ,
90+ diff_dst_iter_desc : MemoryDescriptor ,
8091 flags : c_uint ,
81- hint_fwd_pd : & ' a PrimitiveDescriptor ,
82- attr : dnnl_primitive_attr_t ,
92+ hint_fwd_pd : & ' a PrimitiveDescriptor < ' a , Forward , PropForwardTraining , ForwardAuGruConfig > ,
93+ attr : PrimitiveAttributes ,
8394}
8495
8596impl < ' a , P : PropType < Backward > > PrimitiveConfig < ' a , Backward , P > for BackwardAuGruConfig < ' a > {
86- fn create_primitive_desc ( & self , engine : Arc < Engine > ) -> Result < PrimitiveDescriptor , DnnlError > {
97+ fn create_primitive_desc (
98+ self ,
99+ engine : Arc < Engine > ,
100+ ) -> Result < PrimitiveDescriptor < ' a , Backward , P , BackwardAuGruConfig < ' a > > , DnnlError > {
87101 let mut handle = std:: ptr:: null_mut ( ) ;
88102 let status = unsafe {
89103 dnnl_augru_backward_primitive_desc_create (
@@ -109,12 +123,20 @@ impl<'a, P: PropType<Backward>> PrimitiveConfig<'a, Backward, P> for BackwardAuG
109123 self . diff_dst_iter_desc . handle ,
110124 self . flags ,
111125 self . hint_fwd_pd . handle ,
112- self . attr ,
126+ self . attr . handle ,
113127 )
114128 } ;
115129
116130 if status == dnnl_status_t:: dnnl_success {
117- Ok ( PrimitiveDescriptor { handle } )
131+ Ok (
132+ PrimitiveDescriptor :: < ' a , Backward , P , BackwardAuGruConfig < ' a > > {
133+ handle,
134+ config : self ,
135+ _marker_a : PhantomData ,
136+ _marker_d : PhantomData ,
137+ _marker_p : PhantomData ,
138+ } ,
139+ )
118140 } else {
119141 Err ( status. into ( ) )
120142 }
@@ -125,9 +147,9 @@ pub struct ForwardAuGru<P: PropType<Forward>> {
125147 pub prop_type : P ,
126148}
127149
128- impl < ' a , P : PropType < Forward > > Operation < ' a , Forward , P > for ForwardAuGru < P > {
150+ impl < P : PropType < Forward > > Operation < ' _ , Forward , P > for ForwardAuGru < P > {
129151 const TYPE : OperationType = OperationType :: Augru ;
130- type OperationConfig = ForwardAuGruConfig < ' a > ;
152+ type OperationConfig = ForwardAuGruConfig ;
131153}
132154
133155pub struct BackwardAuGru < P : PropType < Backward > > {
0 commit comments