diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index bcb70762ee05e..142dd66903aaa 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,4 +1,5 @@ #include "conv2d.cuh" +#include "convert.cuh" struct conv_params { const int64_t IW, IH; @@ -94,8 +95,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input, const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; - const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; - acc += (input_val * kernel_val); + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * ggml_cuda_cast(kernel_val)); } } }