Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/conv2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);

const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * kernel_val);
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
acc += (input_val * ggml_cuda_cast<float>(kernel_val));

The second type can be inferred from the argument, I think this is more legible.

}
}
}
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/convert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
if constexpr (std::is_same_v<dst_t, src_t>) {
return x;
} else if constexpr (std::is_same_v<dst_t, float> && std::is_same_v<src_t, half>) {
return __half2float(x);
} else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
Expand Down
Loading