
Conversation

@qnixsynapse
Collaborator

@qnixsynapse qnixsynapse commented Aug 31, 2025

Building conv2d with half precision failed because __half defines multiple implicit conversion operators (to float, int, short, etc.), causing ambiguous overload resolution when multiplying with float.

Introduce a templated to_float helper that explicitly converts __half via __half2float, while passing through float unchanged. Use this helper in conv2d accumulation to ensure unambiguous and correct promotion to float.

Use ggml_cuda_cast from convert.cuh for casting the value to float instead.

Fixes some build errors with half-precision kernels on CUDA.

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Aug 31, 2025
@qnixsynapse
Collaborator Author

qnixsynapse commented Aug 31, 2025

Seems CI is failing because of #15434, hence unrelated. The CI scripts need to be updated.

@JohannesGaessler
Collaborator

Which build errors?

In any case, in convert.cuh there is already a template ggml_cuda_cast for something like this.

@qnixsynapse
Collaborator Author

@JohannesGaessler Here is the error log. It suddenly started failing after conv2d got merged.

/home/runner/actions-runner/_work/llama.cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu(104): error: more than one conversion function from "half" to a built-in type applies:
            function "__half::operator float() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(204): here
            function "__half::operator short() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(222): here
            function "__half::operator unsigned short() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(225): here
            function "__half::operator int() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(228): here
            function "__half::operator unsigned int() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(231): here
            function "__half::operator long long() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(234): here
            function "__half::operator unsigned long long() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(237): here
            function "__half::operator __nv_bool() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(241): here
          detected during:
            instantiation of "void conv2d_kernel<T,Layout>(const float *, const T *, float *, conv_params) [with T=half, Layout=whcn_layout]" 
(116): here
            instantiation of "void conv2d_cuda(const float *, const T *, float *, conv_params, cudaStream_t) [with T=half]" 
(120): here

/home/runner/actions-runner/_work/llama.cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu(104): error: more than one conversion function from "half" to a built-in type applies:
            function "__half::operator float() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(204): here
            function "__half::operator short() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(222): here
            function "__half::operator unsigned short() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(225): here
            function "__half::operator int() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(228): here
            function "__half::operator unsigned int() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(231): here
            function "__half::operator long long() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(234): here
            function "__half::operator unsigned long long() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(237): here
            function "__half::operator __nv_bool() const"
/usr/local/cuda/targets/x86_64-linux/include/cuda_fp16.hpp(241): here
          detected during:
            instantiation of "void conv2d_kernel<T,Layout>(const float *, const T *, float *, conv_params) [with T=half, Layout=whcn_layout]" 
(116): here
            instantiation of "void conv2d_cuda(const float *, const T *, float *, conv_params, cudaStream_t) [with T=half]" 
(120): here

2 errors detected in the compilation of "/home/runner/actions-runner/_work/llama.cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu".
make[4]: *** [ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/build.make:242: ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/conv2d.cu.o] Error 1
make[4]: *** Waiting for unfinished jobs....
[ 23%] Linking CXX static library libggml-cpu.a
make[4]: Leaving directory '/home/runner/actions-runner/_work/llama.cpp/llama.cpp/build'
[ 23%] Built target ggml-cpu
make[3]: *** [CMakeFiles/Makefile2:987: ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/all] Error 2
make[4]: Leaving directory '/home/runner/actions-runner/_work/llama.cpp/llama.cpp/build'
make[2]: *** [CMakeFiles/Makefile2:2240: tools/server/CMakeFiles/llama-server.dir/rule] Error 2
make[1]: *** [Makefile:728: llama-server] Error 2
make: *** [Makefile:30: build-lib] Error 2
make[3]: Leaving directory '/home/runner/actions-runner/_work/llama.cpp/llama.cpp/build'
make[2]: Leaving directory '/home/runner/actions-runner/_work/llama.cpp/llama.cpp/build'
make[1]: Leaving directory '/home/runner/actions-runner/_work/llama.cpp/llama.cpp/build'

In any case, in convert.cuh there is already a template ggml_cuda_cast for something like this.

I see. Lemme check and update it.

@JohannesGaessler
Collaborator

I see. Lemme check and update it.

I don't think you need to update the template; if it were ambiguous you would already be getting compilation failures at other points in the code.

@qnixsynapse
Copy link
Collaborator Author

qnixsynapse commented Aug 31, 2025

if it were ambiguous you would already be getting compilation failures at other points in the code.

I see. Fortunately, this change fixed the build in our testing. If you want me to use the template ggml_cuda_cast from convert.cuh here, please let me know.

@JohannesGaessler
Collaborator

Please just use the template in convert.cuh, extend the template if and only if you still get compilation failures using it.

@JohannesGaessler JohannesGaessler dismissed their stale review August 31, 2025 09:35

Forgot to check CI passing.

@JohannesGaessler
Collaborator

As I said:

extend the template if and only if you still get compilation failures using it.

The code did not compile in the first place due to the missing header. So did you actually confirm that the additional branch in the template is necessary?

@qnixsynapse
Collaborator Author

Not yet. I will let you know when I do that. Should I convert this PR to draft?

@JohannesGaessler
Collaborator

I think it doesn't really matter, just request a review when ready.

@qnixsynapse
Collaborator Author

Okay, I think it is good to go now. It now builds without failure and test-backend-ops passes successfully.

-    const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
-    acc += (input_val * kernel_val);
+    const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+    acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
Collaborator


Suggested change
acc += (input_val * ggml_cuda_cast<float, T>(kernel_val));
acc += (input_val * ggml_cuda_cast<float>(kernel_val));

The second type can be inferred from the argument, I think this is more legible.

@qnixsynapse qnixsynapse merged commit b66df9d into master Sep 1, 2025
46 of 48 checks passed
@qnixsynapse qnixsynapse deleted the cuda/fix_conv2d_type branch September 1, 2025 01:25
Minh141120 pushed a commit to menloresearch/llama.cpp that referenced this pull request Sep 5, 2025
…ml-org#15690)

* CUDA: fix build error from ambiguous __half conversions in conv2d

Building conv2d with half precision failed because `__half` defines
multiple implicit conversion operators (to float, int, short, etc.),
causing ambiguous overload resolution when multiplying with float.

Introduce a templated `to_float` helper that explicitly converts
`__half` via `__half2float`, while passing through float unchanged.
Use this helper in conv2d accumulation to ensure unambiguous and
correct promotion to float.

Fixes some build errors with half-precision kernels on CUDA.

ggml-ci

* CUDA: Replace custom to_float helper with unified ggml_cuda_cast and add half‑>float conversion

* CUDA: Add missing convert.cuh header

* CUDA: remove unnecessary extension in ggml_cuda_cast

* CUDA: Address review comment, remove second type template argument
walidbr pushed a commit to walidbr/llama.cpp that referenced this pull request Sep 7, 2025
…ml-org#15690)

