@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
 
-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
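Note: the three temporary buffers in the hunks above (src0_as_f16, src1_as_f16, dst_f16) are now drawn from ctx.pool(id), the memory pool of the device that executes this slice of the op, instead of the pool of the current device. The conversion itself is still done by the to_fp16_cuda helper obtained from ggml_get_to_fp16_cuda. As a rough illustration of what such a conversion step amounts to, here is a minimal fp32-to-fp16 copy kernel; the names k_f32_to_f16 and f32_to_f16_cuda are made up for this sketch and are not the ggml implementation:

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>

    // Convert a packed fp32 buffer to fp16 on the given stream so it can feed a
    // half-precision GEMM (illustrative only, not ggml's to_fp16_cuda).
    static __global__ void k_f32_to_f16(const float * __restrict__ x, half * __restrict__ y, const int64_t n) {
        const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = __float2half(x[i]);
        }
    }

    static void f32_to_f16_cuda(const float * x, half * y, const int64_t n, cudaStream_t stream) {
        const int block_size = 256;
        const int num_blocks = (int)((n + block_size - 1)/block_size);
        k_f32_to_f16<<<num_blocks, block_size, 0, stream>>>(x, y, n);
    }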
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+        const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+        int64_t ne11, int64_t ne10,
+        size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+        const mmid_row_mapping * __restrict__ row_mapping,
+        int64_t ne0,
+        size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
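Note: the two kernels added above implement a gather/scatter scheme for the batched expert multiplication. For each expert i02, every src1 row routed to it is compacted into a contiguous staging buffer so that one batched matrix multiplication can process them, and the (i1, i2) position of each compacted row is recorded in mmid_row_mapping so the results can later be scattered back. The core trick is that each matching block claims its slot in the staging buffer with an atomicAdd on a shared counter. A self-contained sketch of that pattern, with illustrative names (k_gather_rows, k_scatter_rows, slot_map) that are not part of ggml:

    #include <cuda_runtime.h>
    #include <stdint.h>

    struct slot_map {
        int32_t orig_row; // where the compacted row came from in the original buffer
    };

    static __global__ void k_gather_rows(const float * __restrict__ src, float * __restrict__ compact,
                                         const int32_t * __restrict__ routed_expert, int expert,
                                         int * __restrict__ counter, slot_map * __restrict__ map,
                                         int64_t row_len) {
        const int row = blockIdx.x;          // one block per candidate row
        if (routed_expert[row] != expert) {
            return;                          // this row was routed to a different expert
        }

        __shared__ int slot;
        if (threadIdx.x == 0) {
            slot = atomicAdd(counter, 1);    // claim the next free slot in the dense buffer
            map[slot].orig_row = row;        // remember where the row came from
        }
        __syncthreads();

        // all threads of the block cooperatively copy the row into its slot
        for (int64_t i = threadIdx.x; i < row_len; i += blockDim.x) {
            compact[slot*row_len + i] = src[row*row_len + i];
        }
    }

    static __global__ void k_scatter_rows(float * __restrict__ dst, const float * __restrict__ compact,
                                          const slot_map * __restrict__ map, int64_t row_len) {
        const int slot = blockIdx.x;         // one block per compacted row
        const int row  = map[slot].orig_row;
        for (int64_t i = threadIdx.x; i < row_len; i += blockDim.x) {
            dst[row*row_len + i] = compact[slot*row_len + i];
        }
    }

Because the slot order produced by atomicAdd depends on block scheduling, the compaction order is nondeterministic; recording the mapping and reusing it for the scatter is what keeps the final result independent of that order.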
@@ -1982,27 +2035,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row = *dst;
 
     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original = (char *) dst->data;
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;
 
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data = dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
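Note: in the ne12 == 1 path above, each (token iid1, expert slot id) pair becomes an independent matrix multiplication by pointing the temporary src0_row/src1_row/dst_row views at the right byte offsets; the next hunk handles the batched case. Conceptually the computation is the following reference loop, written here as a sketch over plain row-major float buffers rather than ggml's strided tensors, and assuming a single activation row per token (ne11 == 1; in the kernel the row is picked as id % ne11). The function name mul_mat_id_ref is made up:

    #include <stdint.h>

    // For every token and every selected expert slot, multiply the chosen expert's
    // weight matrix with the token's activation row and write the result into the
    // output slot for that (token, slot) pair.
    void mul_mat_id_ref(const float * experts,   // [n_as][ne01][ne00] expert weight matrices
                        const float * acts,      // [n_tokens][ne00]   activation rows (ne10 == ne00)
                        const int32_t * ids,     // [n_tokens][n_ids]  selected expert per slot
                        float * out,             // [n_tokens][n_ids][ne01] results
                        int64_t n_ids, int64_t n_tokens,
                        int64_t ne00, int64_t ne01) {
        for (int64_t iid1 = 0; iid1 < n_tokens; ++iid1) {       // token index (i2)
            for (int64_t id = 0; id < n_ids; ++id) {            // expert slot index (i1)
                const int32_t e = ids[iid1*n_ids + id];         // expert chosen for this slot
                const float * W = experts + e*ne01*ne00;
                const float * x = acts + iid1*ne00;
                float * y = out + (iid1*n_ids + id)*ne01;
                for (int64_t r = 0; r < ne01; ++r) {
                    float acc = 0.0f;
                    for (int64_t c = 0; c < ne00; ++c) {
                        acc += W[r*ne00 + c] * x[c];
                    }
                    y[r] = acc;
                }
            }
        }
    }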
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;
 
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
             }
         }
     }
@@ -2487,7 +2575,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
     GGML_UNUSED(backend);
 }
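Note: the offload heuristic above now also admits MUL_MAT_ID nodes whose batch dimension lives in ne[2], since for expert multiplication the token batch ends up there rather than in ne[1]. A compact restatement of the condition under assumed parameter names (not the ggml function itself):

    #include <stdint.h>

    // Illustrative only: returns true when the op is worth offloading to the GPU
    // under the same thresholds as the diff above.
    static bool should_offload(int64_t ne1, int64_t ne2, bool is_get_rows, bool is_mul_mat_id) {
        const int64_t min_batch_size = 32;
        return (ne1 >= min_batch_size && !is_get_rows) ||
               (ne2 >= min_batch_size && is_mul_mat_id);
    }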