1 file changed: +19 -3 lines changed.
@@ -34,23 +34,39 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
// Converts a single value from src_t to dest_t.
//
// When src_t and dest_t are the same type the value is returned untouched,
// so it never takes a round trip through float. For any other pair of types
// the value is first converted to float, and the return statement then
// implicitly converts that float to dest_t.
//
// Template parameters:
//   src_t  - type of the incoming value
//   dest_t - type of the returned value
template <typename src_t, typename dest_t>
__host__ __device__ inline dest_t ggml_cuda_convert_val(src_t x)
{
    if constexpr (std::is_same_v<src_t, dest_t>) {
        // Identity conversion: no intermediate float.
        return x;
    } else {
        return float(x);
    }
}
3943
// Specialization: nv_bfloat16 -> float.
// Uses the __bfloat162float intrinsic instead of the generic float(x) cast.
template <>
__host__ __device__ inline float ggml_cuda_convert_val<nv_bfloat16, float>(nv_bfloat16 x)
{
    return __bfloat162float(x);
}
4549
// Specialization: nv_bfloat16 -> nv_bfloat16.
// Identity conversion; the value is returned unchanged.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<nv_bfloat16, nv_bfloat16>(nv_bfloat16 x)
{
    return x;
}
5155
// Specialization: float -> nv_bfloat16.
// Uses the __float2bfloat16 intrinsic to narrow the value.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<float, nv_bfloat16>(float x)
{
    return __float2bfloat16(x);
}
61+
// Specialization: nv_bfloat16 -> half.
// The value is first widened to float with __bfloat162float, and that float
// is then used to construct the half result.
template <>
__host__ __device__ inline half ggml_cuda_convert_val<nv_bfloat16, half>(nv_bfloat16 x)
{
    return half(__bfloat162float(x));
}
67+
// Specialization: half -> nv_bfloat16.
// The value is first converted to float, and that float is then narrowed
// with __float2bfloat16.
template <>
__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<half, nv_bfloat16>(half x)
{
    return __float2bfloat16(float(x));
}
You can’t perform that action at this time.
0 commit comments