@@ -29,3 +29,47 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
// Declarations only (no bodies are visible here): each function takes a ggml_type and
// returns a value of the matching to_*_nc_cuda_t type (fp32, fp16 and bf16 variants).
// NOTE(review): "nc" presumably stands for non-contiguous -- confirm against the definitions.
// NOTE(review): the digits fused onto the start of each line look like line-number residue
// from a diff view -- confirm and strip before compiling.
2929to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda (ggml_type type);
3030to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda (ggml_type type);
3131to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda (ggml_type type);
32+
// Converts a single value from src_t to dest_t. Usable from host and device code.
//
// Template arguments are given in <source, destination> order, e.g.
// ggml_cuda_convert_val<int, float>(v).
//
//  - identical types: the value is returned unchanged;
//  - two built-in arithmetic types: converted directly with static_cast, so wide
//    types (double, 64-bit integers) are not squeezed through float on the way;
//  - any other pair: the value is first converted to float and the return
//    statement then converts that float to dest_t.
template <typename src_t, typename dest_t>
__host__ __device__ inline dest_t ggml_cuda_convert_val(src_t x) {
    if constexpr (std::is_same_v<src_t, dest_t>) {
        return x;
    } else if constexpr (std::is_arithmetic_v<src_t> && std::is_arithmetic_v<dest_t>) {
        // For float <-> int this yields exactly what the float detour below would,
        // but it keeps full precision when either side is wider than float.
        return static_cast<dest_t>(x);
    } else {
        return float(x);
    }
}
41+
// bf16 -> float specialization: delegates to the __bfloat162float intrinsic.
template <>
__host__ __device__ inline float ggml_cuda_convert_val<nv_bfloat16, float>(nv_bfloat16 x) {
    const float as_f32 = __bfloat162float(x);
    return as_f32;
}
46+
// bf16 -> bf16 specialization: identity, the argument is handed back as-is.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<nv_bfloat16, nv_bfloat16>(nv_bfloat16 x) {
    const nv_bfloat16 unchanged = x;
    return unchanged;
}
51+
// float -> bf16 specialization: delegates to the __float2bfloat16 intrinsic.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<float, nv_bfloat16>(float x) {
    const nv_bfloat16 as_bf16 = __float2bfloat16(x);
    return as_bf16;
}
56+
// bf16 -> half specialization: the value is first expanded to float with
// __bfloat162float and the half is then constructed from that float.
template <>
__host__ __device__ inline half ggml_cuda_convert_val<nv_bfloat16, half>(nv_bfloat16 x) {
    const float as_f32 = __bfloat162float(x);
    return half(as_f32);
}
61+
// half -> bf16 specialization: the half is converted to float first, then
// __float2bfloat16 produces the bf16 result.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<half, nv_bfloat16>(half x) {
    const float as_f32 = float(x);
    return __float2bfloat16(as_f32);
}
66+
// bf16 -> int specialization: expands to float via __bfloat162float, then
// casts that float to int (a float-to-int cast drops the fractional part).
template <>
__host__ __device__ inline int ggml_cuda_convert_val<nv_bfloat16, int>(nv_bfloat16 x) {
    const float as_f32 = __bfloat162float(x);
    return static_cast<int>(as_f32);
}
71+
// int -> bf16 specialization: the integer is converted to float first and
// __float2bfloat16 then produces the bf16 result, so the value is rounded
// in two steps (int -> float, float -> bf16).
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<int, nv_bfloat16>(int x) {
    const float as_f32 = static_cast<float>(x);
    return __float2bfloat16(as_f32);
}
0 commit comments