 #include "common.cuh"
 #include "mmv.cuh"

-template <typename type_acc, int block_size>
+template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
     const int64_t row     = blockIdx.x;
     const int64_t channel = blockIdx.z;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
     y   +=  channel               *stride_channel_y;
     dst +=  channel               *stride_channel_dst;

-    const half2  * x2 = (const half2  *) x;
     const float2 * y2 = (const float2 *) y;

     extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(

     float sumf;

-    if (std::is_same<type_acc, float>::value) {
-        sumf = 0.0f;
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 * x2 = (const half2 *) x;

-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmpx = __half22float2(x2[col2]);
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x * tmpy.x;
-            sumf += tmpx.y * tmpy.y;
-        }
-    } else {
+        if (std::is_same<type_acc, float>::value) {
+            sumf = 0.0f;
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                const float2 tmpy = y2[col2];
+                sumf += tmpx.x * tmpy.x;
+                sumf += tmpx.y * tmpy.y;
+            }
+        } else {
 #ifdef FP16_AVAILABLE
-        half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2 = make_half2(0.0f, 0.0f);

-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmp = y2[col2];
-            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
-        }
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmp = y2[col2];
+                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            }

-        sumf = __low2float(sumh2) + __high2float(sumh2);
+            sumf = __low2float(sumh2) + __high2float(sumh2);
 #else
-        NO_DEVICE_CODE;
+            NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+        const int * x2 = (const int *) x;
+        sumf = 0.0f;
+
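+        // Each 32-bit load in the loop below covers one pair of bf16 values; the two
+        // halves are reinterpreted as nv_bfloat16 and widened to float, so this path
+        // always accumulates in fp32 regardless of type_acc.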
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int    tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        }
+    } else {
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }

     sumf = warp_reduce_sum(sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
     dst[row] = sumf;
 }

-template <typename type_acc>
+template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   64: {
-            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   96: {
-            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  128: {
-            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  160: {
-            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  192: {
-            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  224: {
-            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  256: {
-            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         default: {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
     }
 }

+template <typename T>
 static void mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
     }
 }

 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);

@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

-    const half  * src0_d = (const half  *) src0->data;
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;

@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);

-    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
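+    // The F16 and BF16 cases below dispatch to the same templated mul_mat_vec_cuda;
+    // only the element type of the src0 pointer differs.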
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 }

 void ggml_cuda_op_mul_mat_vec(
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);

@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t channel_stride_y   = 0;
     const int64_t channel_stride_dst = 0;

-    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }

     GGML_UNUSED(ctx);
     GGML_UNUSED(src1);