@@ -1410,6 +1410,11 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
+
     const int64_t nb2 = dst->nb[2];
     const int64_t nb3 = dst->nb[3];
 
@@ -1545,7 +1550,10 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                quantize_src1(
+                    dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                    nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
+                    src1_padded_col_size, ne11, ne12, ne13, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -1640,7 +1648,9 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 if (quantize_src1 && !src1_is_contiguous) {
-                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+                    quantize_src1(
+                        src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                        src1_padded_col_size, src1_ncols, 1, 1, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
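Note: both quantize_src1 call sites above assume a reworked quantization entry point that takes explicit src1 strides in units of floats (hence nb11/sizeof(float) etc.) instead of a fused row count, so that a non-contiguous src1 can be quantized directly. A minimal sketch of the function-pointer type as inferred from the two calls; the parameter names are assumptions, not taken from this diff:

    // Sketch only: shape inferred from the call sites, parameter names are guesses.
    typedef void (*quantize_cuda_t)(
        const float * x, void * vy, ggml_type type_src0,
        int64_t ne00,                           // row length of x
        int64_t s01, int64_t s02, int64_t s03,  // strides of x in floats, dims 1..3
        int64_t ne00_padded,                    // row length padded to the quantized block layout
        int64_t ne1, int64_t ne2, int64_t ne3,  // extents to quantize in dims 1..3
        cudaStream_t stream);

In the contiguous branch the strides come straight from nb11/nb12/nb13; in the non-contiguous branch src1_ddf_i is already a dense per-device copy, so the dense strides ne10, ne11*ne10 and ne12*ne11*ne10 are passed and only src1_ncols rows of a single plane are quantized.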
@@ -1878,7 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -1919,10 +1929,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_vec_q) {
+        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
                && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
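Note: the vector-kernel entry points now take an extra ids argument, nullptr here because this is the plain GGML_OP_MUL_MAT path, and the quantized vector path (use_mul_mat_vec_q) gets its own branch ahead of the cuBLAS-based paths on non-split buffers. Assumed declarations, inferred only from the call arity in this diff:

    // Assumed declarations (not shown in this diff); ids may be nullptr.
    void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0,
            const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
    void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0,
            const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);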
@@ -1999,6 +2011,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
+        if (ggml_is_quantized(src0->type)) {
+            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+        } else {
+            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+        }
+        return;
+    }
+
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
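Note: this early-out covers the common token-generation case. ne2 (the dst batch dimension) equal to 1 means a single token per node, so each selected expert multiplies a single activation column and the ids-aware vector kernels can resolve the expert indices on the device, skipping the host-side gather below. A hypothetical shape walk-through, assuming two experts used per token (illustration only, not from this diff):

    // src0: [ne00, ne01, n_as]       stacked expert weight matrices
    // src1: [ne00, n_expert_used, 1] one activation column per used expert (ne2 == 1)
    // ids : [n_expert_used, 1]       int32 expert indices, read directly by the kernel
    // dst : [ne01, n_expert_used, 1] one output column per used expert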
@@ -2035,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[2] = nb1;
     dst_row.nb[3] = nb1;
 
-    if (ne12 == 1) {
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
-
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
-
-                src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
-
-                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-            }
-        }
-    } else {
-        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
+    ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+    ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
+    src1_row.data = src1_contiguous.get();
+    dst_row.data  =  dst_contiguous.get();
 
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+    for (int64_t i02 = 0; i02 < n_as; i02++) {
+        int64_t num_src1_rows = 0;
 
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
+                GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                    num_src1_rows++;
+                if (row_id_i != i02) {
+                    continue;
                 }
-            }
 
-            if (num_src1_rows == 0) {
-                continue;
+                num_src1_rows++;
             }
+        }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        if (num_src1_rows == 0) {
+            continue;
+        }
 
-            {
-                dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+        {
+            dim3 block_dims(std::min((unsigned int)ne10, 768u));
+            dim3 grid_dims(ids->ne[1], n_ids);
+            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(),
+                    dev_cur_src1_row.get(), dev_row_mapping.get(),
+                    ids_dev, i02, ids->nb[1], ids->nb[0],
+                    ne11, ne10,
+                    nb11, nb12);
+            CUDA_CHECK(cudaGetLastError());
+        }
 
-            src0_row.data = src0_original + i02*nb02;
+        src0_row.data = src0_original + i02*nb02;
 
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+        GGML_ASSERT(nb11 == sizeof(float)*ne10);
+        GGML_ASSERT(nb1 == sizeof(float)*ne0);
 
-            src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
+        src1_row.ne[1] = num_src1_rows;
+        src1_row.nb[1] = nb11;
+        src1_row.nb[2] = num_src1_rows*nb11;
+        src1_row.nb[3] = num_src1_rows*nb11;
 
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
+        dst_row.ne[1] = num_src1_rows;
+        dst_row.nb[1] = nb1;
+        dst_row.nb[2] = num_src1_rows*nb1;
+        dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+        ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-            {
-                dim3 block_dims(std::min((unsigned int)ne0, 768u));
-                dim3 grid_dims(num_src1_rows);
-                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
-                        ne0,
-                        nb1, nb2);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        {
+            dim3 block_dims(std::min((unsigned int)ne0, 768u));
+            dim3 grid_dims(num_src1_rows);
+            k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    dst_original, dst_contiguous.get(),
+                    dev_row_mapping.get(),
+                    ne0,
+                    nb1, nb2);
+            CUDA_CHECK(cudaGetLastError());
         }
     }
 }
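Note: the old ne12 == 1 special case at the top of this hunk issued one host-driven ggml_cuda_mul_mat call per expert id; it is superseded by the vector-kernel fast path added before the GGML_ASSERT above, so only the gather/scatter path through src1_contiguous and dst_contiguous remains (the surviving code is unchanged, just de-indented out of the removed else block).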
@@ -2489,7 +2488,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID) {
+        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
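Note: relaxing the check to node->ne[2] != 1 follows from the fast path above: single-token MUL_MAT_ID now runs through the vector kernels, which are capturable, while larger batches still use the synchronizing gather/scatter path and continue to disable CUDA graph capture.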
@@ -3203,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
-            const size_t ts = ggml_type_size(op->src[0]->type);
-            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
-            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+            return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
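Note on the ROPE/ROPE_BACK change: the explicit nb[3] arithmetic is folded into ggml_is_contiguous_2(), which presumably permits non-dense strides in dims 1 and 2 while requiring dense packing above them; together with the remaining nb[0] check this covers the layouts the removed code accepted.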