@@ -29,3 +29,56 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
2929to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda (ggml_type type);
3030to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda (ggml_type type);
3131to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda (ggml_type type);
32+
33+
34+ template <typename src_t , typename dest_t >
35+ __host__ __device__ inline dest_t ggml_cuda_convert_val (src_t x)
36+ {
37+ if constexpr (std::is_same_v<src_t , dest_t >) {
38+ return x;
39+ } else {
40+ return float (x);
41+ }
42+ }
43+
44+ template <>
45+ __host__ __device__ inline float ggml_cuda_convert_val<nv_bfloat16, float >(nv_bfloat16 x)
46+ {
47+ return __bfloat162float (x);
48+ }
49+
50+ template <>
51+ __host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<nv_bfloat16, nv_bfloat16>(nv_bfloat16 x)
52+ {
53+ return x;
54+ }
55+
56+ template <>
57+ __host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<float , nv_bfloat16>(float x)
58+ {
59+ return __float2bfloat16 (x);
60+ }
61+
62+ template <>
63+ __host__ __device__ inline half ggml_cuda_convert_val<nv_bfloat16, half>(nv_bfloat16 x)
64+ {
65+ return half (__bfloat162float (x));
66+ }
67+
68+ template <>
69+ __host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<half, nv_bfloat16>(half x)
70+ {
71+ return __float2bfloat16 (float (x));
72+ }
73+
74+ template <>
75+ __host__ __device__ inline int ggml_cuda_convert_val<nv_bfloat16, int >(nv_bfloat16 x)
76+ {
77+ return int (__bfloat162float (x));
78+ }
79+
80+ template <>
81+ __host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<int , nv_bfloat16>(int x)
82+ {
83+ return __float2bfloat16 (float (x));
84+ }
0 commit comments