22#include " common.cuh"
33#include " mmv.cuh"
44
5+ template <typename T, typename type_acc, int ncols_dst, int block_size>
56template <typename T, typename type_acc, int ncols_dst, int block_size>
67static __global__ void mul_mat_vec (
78 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -15,10 +16,25 @@ static __global__ void mul_mat_vec(
1516 const int sample_dst = blockIdx .z ;
1617 const int sample_x = sample_dst / sample_ratio;
1718 const int sample_y = sample_dst;
19+ const int tid = threadIdx .x ;
20+
21+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
22+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
23+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
24+ const int row = blockIdx .x ;
25+ const int channel_dst = blockIdx .y ;
26+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
27+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
28+ const int sample_dst = blockIdx .z ;
29+ const int sample_x = sample_dst / sample_ratio;
30+ const int sample_y = sample_dst;
1831 const int tid = threadIdx .x ;
1932
2033 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
2134
35+ x += int64_t (sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
36+ y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y;
37+ dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
2238 x += int64_t (sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
2339 y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y;
2440 dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
@@ -456,11 +472,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
456472 return ne11 <= 4 ;
457473 }
458474 return ne11 <= 3 ;
459- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
460- if (fp32_mma_hardware_available (cc)) {
461- return ne11 <= 3 ;
462- }
463- return ne11 <= 8 ;
464475 }
465476 return ne11 <= 8 ;
466477 case GGML_TYPE_F16:
@@ -473,14 +484,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
473484 return src0_small && ne11 <= 3 ;
474485 }
475486 return ne11 <= 8 ;
476- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
477- if (fp16_mma_hardware_available (cc)) {
478- if (GGML_CUDA_CC_IS_RDNA3 (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
479- return ne11 <= 5 ;
480- }
481- return ne11 <= 2 ;
482- }
483- return ne11 <= 8 ;
484487 }
485488 return ne11 <= 8 ;
486489 case GGML_TYPE_BF16:
@@ -493,11 +496,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
493496 return src0_small && ne11 <= 3 ;
494497 }
495498 return ne11 <= 8 ;
496- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
497- if (bf16_mma_hardware_available (cc)) {
498- return ne11 <= 3 ;
499- }
500- return ne11 <= 8 ;
501499 }
502500 return ne11 <= 8 ;
503501 default :
0 commit comments