@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
       // col_matrix = filter * input_batch
       // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-      blas.MatMul(filter, true, input_batch, false, &col_matrix);
+      blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
+                  &col_matrix, static_cast<T>(0.0));
 
       if (data_dim == 2U) {
         // col2im: col_matrix -> dy
@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         // or
         // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
         // d, h, w)
-        blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
+        blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
+                    &input_grad_batch, static_cast<T>(0.0));
       }
       if (filter_grad) {
         // input batch
@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         // or
         // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
         // k_h * k_w)
-        blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
+        blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
+                    &filter_grad_, static_cast<T>(1.0));
       }
     }
   }
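
For context: these calls follow the usual GEMM convention `out = alpha * op(A) * op(B) + beta * out`, in the `MatMul(A, trans_a, B, trans_b, alpha, &out, beta)` argument order seen above. With `beta = 0.0` the output is overwritten, while the `filter_grad_` call passes `beta = 1.0` so that each batch's contribution accumulates into the gradient across the per-batch loop. The sketch below illustrates only that update rule; the standalone `matmul` helper, its row-major layout, and the toy shapes are assumptions for illustration, not Paddle's actual `Blas` implementation.

// Minimal GEMM sketch: C = alpha * op(A) * op(B) + beta * C (row-major).
// Hypothetical helper, for illustration only.
#include <cassert>
#include <cstdio>
#include <vector>

void matmul(const std::vector<float>& A, bool trans_a,
            const std::vector<float>& B, bool trans_b,
            int m, int n, int k,  // op(A) is m x k, op(B) is k x n
            float alpha, std::vector<float>* C, float beta) {
  assert(C->size() == static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // A is stored m x k (or k x m if trans_a); likewise for B.
        float a = trans_a ? A[p * m + i] : A[i * k + p];
        float b = trans_b ? B[j * k + p] : B[p * n + j];
        acc += a * b;
      }
      // beta = 0 overwrites the output; beta = 1 accumulates into it.
      (*C)[i * n + j] = alpha * acc + beta * (*C)[i * n + j];
    }
  }
}

int main() {
  // Toy 1x1 "filter gradient": two batches, each contributing 2.0.
  std::vector<float> in{2.f}, col{1.f}, grad{0.f};
  for (int batch = 0; batch < 2; ++batch) {
    // beta = 1.0 mirrors the filter_grad_ call: sum over batches.
    matmul(in, false, col, false, 1, 1, 1, 1.0f, &grad, 1.0f);
  }
  std::printf("accumulated grad = %g\n", grad[0]);  // prints 4
  return 0;
}

Read this way, moving from the short overload to the explicit alpha/beta overload makes the accumulate-vs-overwrite behavior visible at each call site instead of leaving it implicit.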