Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ggml/src/ggml-cuda/conv2d.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "conv2d.cuh"
#include "convert.cuh"

struct conv_params {
const int64_t IW, IH;
Expand Down Expand Up @@ -94,8 +95,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);

const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * kernel_val);
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
acc += (input_val * ggml_cuda_cast<float>(kernel_val));

The second template argument can be deduced from the function argument, so it does not need to be spelled out; I think this is more legible.

}
}
}
Expand Down
Loading