@@ -79,15 +79,15 @@ inline void get_unsqueezed_dim_order(
  * in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
  * out_W.
  */
-template <typename CTYPE>
+template <typename CTYPE, typename CTYPE_BIAS>
 void conv2d_impl(
     const CTYPE* const in_ptr,
     SizesArrayRef in_sizes,
     StridesArrayRef in_strides,
     const CTYPE* const w_ptr,
     SizesArrayRef w_sizes,
     StridesArrayRef w_strides,
-    const CTYPE* const bias_ptr,
+    const CTYPE_BIAS* const bias_ptr,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
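
Note on the hunk above: with a single CTYPE template parameter, the kernel could only consume a bias tensor whose dtype matched the input and weight. Splitting the bias out into its own CTYPE_BIAS parameter lets the kernel accept a bias of a different dtype and cast it at the accumulation site. A minimal standalone sketch of the pattern (names here are illustrative, not from this patch):

#include <cstddef>

template <typename CTYPE, typename CTYPE_BIAS>
CTYPE dot_plus_bias(
    const CTYPE* a,
    const CTYPE* b,
    size_t n,
    const CTYPE_BIAS* bias) {  // may be nullptr when no bias is supplied
  // Accumulate in CTYPE so the hot multiply-accumulate loop never
  // touches the bias type.
  CTYPE accum = 0;
  for (size_t i = 0; i < n; ++i) {
    accum += a[i] * b[i];
  }
  if (bias != nullptr) {
    // One cast per output element, mirroring the conv2d_impl change.
    accum += static_cast<CTYPE>(*bias);
  }
  return accum;
}
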
@@ -174,15 +174,15 @@ void conv2d_impl(
       }
 
       if (bias_ptr != nullptr) {
-        accum += bias_ptr[out_c];
+        accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
       }
       size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
       out_ptr[out_idx] = accum;
     }
   }
 }
 
-template <typename CTYPE>
+template <typename CTYPE, typename CTYPE_BIAS>
 void convolution_wrapper(
     const Tensor& in,
     const Tensor& weight,
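
The convert<CTYPE, CTYPE_BIAS> call makes the bias cast explicit at the single point where the bias enters the accumulator. Assuming convert behaves like a static_cast from the source type to the destination type (its exact definition lives in the ExecuTorch utility headers, not in this patch), a stand-in would look like:

// Hypothetical stand-in for the convert<To, From> helper used above;
// the real helper may add extra handling (e.g. for bool) beyond a
// plain cast.
template <typename To, typename From>
To convert(From val) {
  return static_cast<To>(val);
}

// Usage matching the patched line:
//   accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
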
@@ -289,8 +289,8 @@ void convolution_wrapper(
   CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
   const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
   const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
-  const CTYPE* const bias_ptr =
-      bias.has_value() ? bias.value().const_data_ptr<CTYPE>() : nullptr;
+  const CTYPE_BIAS* const bias_ptr =
+      bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
 
   for (size_t batch = 0; batch < out_N; ++batch) {
     for (size_t group = 0; group < groups; ++group) {
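
Fetching the bias through const_data_ptr<CTYPE_BIAS>() is only sound because the dispatch in the next hunk instantiates CTYPE_BIAS from the bias tensor's runtime scalar_type(); reading the buffer through a mismatched type would reinterpret raw bytes. A hedged sketch of that invariant (CppTypeToScalarType is assumed to map a C++ type to its ScalarType, as in the ExecuTorch headers):

// Sketch: assert that the template argument matches the tensor's dtype
// before handing out a typed pointer.
template <typename CTYPE_BIAS>
const CTYPE_BIAS* checked_bias_ptr(const Tensor& bias) {
  ET_CHECK_MSG(
      bias.scalar_type() == CppTypeToScalarType<CTYPE_BIAS>::value,
      "bias dtype does not match CTYPE_BIAS");
  return bias.const_data_ptr<CTYPE_BIAS>();
}
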
@@ -424,9 +424,16 @@ Tensor& convolution_out(
   Error err = resize_tensor(out, {output_sizes, output_ndim});
   ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
 
-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "convolution", CTYPE, [&]() {
-    convolution_wrapper<CTYPE>(
-        in, weight, bias, stride, padding, dilation, groups, out);
+  ScalarType in_type = in.scalar_type();
+  ScalarType bias_type = in_type;
+  if (bias.has_value()) {
+    bias_type = bias.value().scalar_type();
+  }
+  ET_SWITCH_REAL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, bias_type, ctx, __func__, CTYPE_BIAS, [&]() {
+      convolution_wrapper<CTYPE, CTYPE_BIAS>(
+          in, weight, bias, stride, padding, dilation, groups, out);
+    });
   });
 
   return out;
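
For readers unfamiliar with the ET_SWITCH_* macros: they expand, roughly, into nested switch statements over ScalarType, instantiating convolution_wrapper once per (CTYPE, CTYPE_BIAS) combination; the _AND(Bool, ...) variant additionally admits a bool-typed bias. A hand-expanded sketch of the shape of that dispatch (two dtypes shown for brevity; parameter types mirror the surrounding code and are assumptions here):

// Illustrative, hand-expanded equivalent of the nested dispatch above.
// The real macros enumerate every real dtype and report an error
// through ctx for anything unhandled.
void dispatch_sketch(
    ScalarType in_type,
    ScalarType bias_type,
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    Tensor& out) {
  switch (in_type) {
    case ScalarType::Float:
      switch (bias_type) {
        case ScalarType::Float:
          convolution_wrapper<float, float>(
              in, weight, bias, stride, padding, dilation, groups, out);
          break;
        case ScalarType::Bool:
          convolution_wrapper<float, bool>(
              in, weight, bias, stride, padding, dilation, groups, out);
          break;
        default:
          ET_CHECK_MSG(false, "Unhandled bias dtype");
      }
      break;
    default:
      ET_CHECK_MSG(false, "Unhandled input dtype");
  }
}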