Skip to content

Commit 38ad381

Browse files
CUDA: use FP32 arithmetic for conv2d (ggml-org#15683)
1 parent 696fccf commit 38ad381

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
8282
int64_t n, c_out, out_y, out_x;
8383
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
8484

85-
T acc = 0;
85+
float acc = 0.0f;
8686

8787
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
8888
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
@@ -93,21 +93,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
9393
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
9494
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
9595

96-
T input_val;
97-
if (std::is_same<T, half>::value) {
98-
input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
99-
} else {
100-
input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
101-
}
102-
103-
T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
96+
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
97+
const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
10498
acc += (input_val * kernel_val);
10599
}
106100
}
107101
}
108102

109103
// [N, OC, OH, OW]
110-
output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
104+
output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
111105
}
112106

113107
template <typename T>

0 commit comments

Comments
 (0)