Skip to content

Commit ed681d5

Browse files
author
Abhinav Arora
committed
Fix conv_mkldnn_op.cc which is causing CI failure
1 parent 6f83142 commit ed681d5

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
7272
auto dst_md = platform::MKLDNNMemDesc(
7373
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
7474

75-
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
76-
reinterpret_cast<void*>(input_data));
77-
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
78-
reinterpret_cast<void*>(filter_data));
75+
auto src_memory =
76+
mkldnn::memory({src_md, mkldnn_engine},
77+
reinterpret_cast<void*>(const_cast<T*>(input_data)));
78+
auto weights_memory =
79+
mkldnn::memory({weights_md, mkldnn_engine},
80+
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
7981
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
8082

8183
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
180182
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
181183

182184
// create memory
183-
auto diff_dst_memory =
184-
mkldnn::memory({diff_weights_md, mkldnn_engine},
185-
reinterpret_cast<void*>(output_grad_data));
185+
auto diff_dst_memory = mkldnn::memory(
186+
{diff_weights_md, mkldnn_engine},
187+
reinterpret_cast<void*>(const_cast<T*>(output_grad_data)));
186188
// Retrieve conv_pd from device context
187189
auto conv_pd =
188190
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
202204
auto diff_weights_memory =
203205
mkldnn::memory({diff_weights_md, mkldnn_engine},
204206
reinterpret_cast<void*>(filter_grad_data));
205-
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
206-
reinterpret_cast<void*>(input_data));
207+
auto src_memory =
208+
mkldnn::memory({src_md, mkldnn_engine},
209+
reinterpret_cast<void*>(const_cast<T*>(input_data)));
207210

208211
// create backward conv primitive for weights
209212
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
222225
strides, paddings, *conv_pd, mkldnn_engine);
223226

224227
// create memory
225-
auto diff_src_memory =
226-
mkldnn::memory({diff_src_md, mkldnn_engine},
227-
reinterpret_cast<void*>(input_grad_data));
228-
auto weights_memory = mkldnn::memory(
229-
{weights_md, mkldnn_engine}, reinterpret_cast<void*>(filter_data));
228+
auto diff_src_memory = mkldnn::memory(
229+
{diff_src_md, mkldnn_engine},
230+
reinterpret_cast<void*>(const_cast<T*>(input_grad_data)));
231+
auto weights_memory =
232+
mkldnn::memory({weights_md, mkldnn_engine},
233+
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
230234

231235
// create backward conv primitive for data
232236
auto conv_bwd_data_prim = mkldnn::convolution_backward_data(

0 commit comments

Comments
 (0)