@@ -82,7 +82,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
82
82
int64_t n, c_out, out_y, out_x;
83
83
Layout::unpack_indices (global_idx, P, n, c_out, out_y, out_x);
84
84
85
- T acc = 0 ;
85
+ float acc = 0 . 0f ;
86
86
87
87
for (int64_t c_in = 0 ; c_in < P.IC ; ++c_in) {
88
88
kernel_bounds bounds = calculate_kernel_bounds (out_x, out_y, P);
@@ -93,21 +93,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
93
93
for (int64_t kx = bounds.x_min ; kx < bounds.x_max ; ++kx) {
94
94
const int64_t in_x = calculate_input_coord (out_x, kx, P.ST_X , P.DL_X , P.PD_X );
95
95
96
- T input_val;
97
- if (std::is_same<T, half>::value) {
98
- input_val = __float2half (input[Layout::input_index (n, c_in, in_y, in_x, P)]);
99
- } else {
100
- input_val = input[Layout::input_index (n, c_in, in_y, in_x, P)];
101
- }
102
-
103
- T kernel_val = kernel[Layout::kernel_index (c_out, c_in, ky, kx, P)];
96
+ const float input_val = input[Layout::input_index (n, c_in, in_y, in_x, P)];
97
+ const float kernel_val = kernel[Layout::kernel_index (c_out, c_in, ky, kx, P)];
104
98
acc += (input_val * kernel_val);
105
99
}
106
100
}
107
101
}
108
102
109
103
// [N, OC, OH, OW]
110
- output[Layout::output_index (n, c_out, out_y, out_x, P)] = ( float ) acc;
104
+ output[Layout::output_index (n, c_out, out_y, out_x, P)] = acc;
111
105
}
112
106
113
107
template <typename T>
0 commit comments