@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
72
72
auto dst_md = platform::MKLDNNMemDesc (
73
73
dst_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
74
74
75
- auto src_memory = mkldnn::memory ({src_md, mkldnn_engine},
76
- reinterpret_cast <void *>(input_data));
77
- auto weights_memory = mkldnn::memory ({weights_md, mkldnn_engine},
78
- reinterpret_cast <void *>(filter_data));
75
+ auto src_memory =
76
+ mkldnn::memory ({src_md, mkldnn_engine},
77
+ reinterpret_cast <void *>(const_cast <T*>(input_data)));
78
+ auto weights_memory =
79
+ mkldnn::memory ({weights_md, mkldnn_engine},
80
+ reinterpret_cast <void *>(const_cast <T*>(filter_data)));
79
81
auto dst_memory = mkldnn::memory ({dst_md, mkldnn_engine}, output_data);
80
82
81
83
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
180
182
dst_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
181
183
182
184
// create memory
183
- auto diff_dst_memory =
184
- mkldnn::memory ( {diff_weights_md, mkldnn_engine},
185
- reinterpret_cast <void *>(output_grad_data));
185
+ auto diff_dst_memory = mkldnn::memory (
186
+ {diff_weights_md, mkldnn_engine},
187
+ reinterpret_cast <void *>(const_cast <T*>( output_grad_data) ));
186
188
// Retrieve conv_pd from device context
187
189
auto conv_pd =
188
190
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
202
204
auto diff_weights_memory =
203
205
mkldnn::memory ({diff_weights_md, mkldnn_engine},
204
206
reinterpret_cast <void *>(filter_grad_data));
205
- auto src_memory = mkldnn::memory ({src_md, mkldnn_engine},
206
- reinterpret_cast <void *>(input_data));
207
+ auto src_memory =
208
+ mkldnn::memory ({src_md, mkldnn_engine},
209
+ reinterpret_cast <void *>(const_cast <T*>(input_data)));
207
210
208
211
// create backward conv primitive for weights
209
212
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights (
@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
222
225
strides, paddings, *conv_pd, mkldnn_engine);
223
226
224
227
// create memory
225
- auto diff_src_memory =
226
- mkldnn::memory ({diff_src_md, mkldnn_engine},
227
- reinterpret_cast <void *>(input_grad_data));
228
- auto weights_memory = mkldnn::memory (
229
- {weights_md, mkldnn_engine}, reinterpret_cast <void *>(filter_data));
228
+ auto diff_src_memory = mkldnn::memory (
229
+ {diff_src_md, mkldnn_engine},
230
+ reinterpret_cast <void *>(const_cast <T*>(input_grad_data)));
231
+ auto weights_memory =
232
+ mkldnn::memory ({weights_md, mkldnn_engine},
233
+ reinterpret_cast <void *>(const_cast <T*>(filter_data)));
230
234
231
235
// create backward conv primitive for data
232
236
auto conv_bwd_data_prim = mkldnn::convolution_backward_data (
0 commit comments