Skip to content

Commit af67e28

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: convolution
Reviewed By: SS-JIA Differential Revision: D48288079 fbshipit-source-id: 3b31f378e4df6d0df4b78a69c569b5154195271d
1 parent 827388d commit af67e28

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@ inline void get_unsqueezed_dim_order(
7979
* in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
8080
* out_W.
8181
*/
82-
template <typename CTYPE>
82+
template <typename CTYPE, typename CTYPE_BIAS>
8383
void conv2d_impl(
8484
const CTYPE* const in_ptr,
8585
SizesArrayRef in_sizes,
8686
StridesArrayRef in_strides,
8787
const CTYPE* const w_ptr,
8888
SizesArrayRef w_sizes,
8989
StridesArrayRef w_strides,
90-
const CTYPE* const bias_ptr,
90+
const CTYPE_BIAS* const bias_ptr,
9191
IntArrayRef stride,
9292
IntArrayRef padding,
9393
IntArrayRef dilation,
@@ -174,15 +174,15 @@ void conv2d_impl(
174174
}
175175

176176
if (bias_ptr != nullptr) {
177-
accum += bias_ptr[out_c];
177+
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
178178
}
179179
size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
180180
out_ptr[out_idx] = accum;
181181
}
182182
}
183183
}
184184

185-
template <typename CTYPE>
185+
template <typename CTYPE, typename CTYPE_BIAS>
186186
void convolution_wrapper(
187187
const Tensor& in,
188188
const Tensor& weight,
@@ -289,8 +289,8 @@ void convolution_wrapper(
289289
CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
290290
const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
291291
const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
292-
const CTYPE* const bias_ptr =
293-
bias.has_value() ? bias.value().const_data_ptr<CTYPE>() : nullptr;
292+
const CTYPE_BIAS* const bias_ptr =
293+
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
294294

295295
for (size_t batch = 0; batch < out_N; ++batch) {
296296
for (size_t group = 0; group < groups; ++group) {
@@ -424,9 +424,16 @@ Tensor& convolution_out(
424424
Error err = resize_tensor(out, {output_sizes, output_ndim});
425425
ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
426426

427-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "convolution", CTYPE, [&]() {
428-
convolution_wrapper<CTYPE>(
429-
in, weight, bias, stride, padding, dilation, groups, out);
427+
ScalarType in_type = in.scalar_type();
428+
ScalarType bias_type = in_type;
429+
if (bias.has_value()) {
430+
bias_type = bias.value().scalar_type();
431+
}
432+
ET_SWITCH_REAL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
433+
ET_SWITCH_REAL_TYPES_AND(Bool, bias_type, ctx, __func__, CTYPE_BIAS, [&]() {
434+
convolution_wrapper<CTYPE, CTYPE_BIAS>(
435+
in, weight, bias, stride, padding, dilation, groups, out);
436+
});
430437
});
431438

432439
return out;

0 commit comments

Comments
 (0)