Skip to content

Commit 4738fa6

Browse files
committed
CUDA/HIP: specialize ggml_cuda_convert_val for half <-> bf16 conversions
1 parent 4178180 commit 4738fa6

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

ggml/src/ggml-cuda/convert.cuh

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,39 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
3434
template<typename src_t, typename dest_t>
3535
__host__ __device__ inline dest_t ggml_cuda_convert_val(src_t x)
3636
{
37-
return float(x);
37+
if constexpr (std::is_same_v<src_t, dest_t>) {
38+
return x;
39+
} else {
40+
return float(x);
41+
}
3842
}
3943

4044
template<>
4145
__host__ __device__ inline float ggml_cuda_convert_val<nv_bfloat16, float>(nv_bfloat16 x)
4246
{
43-
return __bfloat162float(x);
47+
return __bfloat162float(x);
4448
}
4549

4650
template<>
4751
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<nv_bfloat16, nv_bfloat16>(nv_bfloat16 x)
4852
{
49-
return x;
53+
return x;
5054
}
5155

5256
template<>
5357
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<float, nv_bfloat16>(float x)
5458
{
5559
return __float2bfloat16(x);
5660
}
61+
62+
template<>
63+
__host__ __device__ inline half ggml_cuda_convert_val<nv_bfloat16, half>(nv_bfloat16 x)
64+
{
65+
return half(__bfloat162float(x));
66+
}
67+
68+
template<>
69+
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<half, nv_bfloat16>(half x)
70+
{
71+
return __float2bfloat16(float(x));
72+
}

0 commit comments

Comments
 (0)