@@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
120
120
math::matmul<Place, T>(context.device_context (), filter, true ,
121
121
input_batch, false , T (1.0 ), &col_matrix, T (0.0 ));
122
122
col2im (context.device_context (), output_batch, col, strides[0 ],
123
- strides[1 ], 0 , 0 );
123
+ strides[1 ], 0 , 0 , 0 , 0 );
124
124
}
125
125
}
126
126
};
@@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
206
206
207
207
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
208
208
im2col (context.device_context (), output_grad_batch, col, strides[0 ],
209
- strides[1 ], paddings[0 ], paddings[1 ]);
209
+ strides[1 ], paddings[0 ], paddings[0 ], paddings[ 1 ], paddings[ 1 ]);
210
210
211
211
// gemm: dx = filter * dy
212
212
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
@@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
238
238
239
239
// im2col: (c * h * w, k_h * k_w)
240
240
im2col (context.device_context (), output_grad_batch, col, strides[0 ],
241
- strides[1 ], paddings[0 ], paddings[1 ]);
241
+ strides[1 ], paddings[0 ], paddings[0 ], paddings[ 1 ], paddings[ 1 ]);
242
242
243
243
// gemm: d_filter = x * y_grad^T
244
244
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
0 commit comments