@@ -52,8 +52,8 @@ class DnnlGemmWrapper {
5252 const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
5353 const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
5454 const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
55- auto a_mem = dnnl::memory (a_in_md, eng, ( void *)a );
56- auto b_mem = dnnl::memory (b_in_md, eng, ( void *)b );
55+ auto a_mem = dnnl::memory (a_in_md, eng, const_cast < void *>(a) );
56+ auto b_mem = dnnl::memory (b_in_md, eng, const_cast < void *>(b) );
5757 auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
5858 auto c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
5959
@@ -80,8 +80,8 @@ class DnnlGemmWrapper {
8080 const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
8181 const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
8282 const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
83- auto a_mem = dnnl::memory (a_in_md, eng, ( void *)a );
84- auto b_mem = dnnl::memory (b_in_md, eng, ( void *)b );
83+ auto a_mem = dnnl::memory (a_in_md, eng, const_cast < void *>(a) );
84+ auto b_mem = dnnl::memory (b_in_md, eng, const_cast < void *>(b) );
8585 auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
8686 auto c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
8787
0 commit comments