44typedef  void  (*set_rows_kernel_t )(const  char  * src, char  * dst);
55
66//  Generic quantized set_rows kernel template
7- template <typename  block_type, int  qk, void  (*quantize_func)(const  float *, block_type*)>
7+ template <typename  idx_t ,  typename   block_type, int  qk, void  (*quantize_func)(const  float *, block_type*)>
88static  __global__  void  k_set_rows_quant (
9-         const  float  * __restrict__  src0, const  int64_t  * __restrict__  src1, block_type * __restrict__  dst,
9+         const  float  * __restrict__  src0, const  idx_t  * __restrict__  src1, block_type * __restrict__  dst,
1010        const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
1111        const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
1212        const  int64_t  s01, const  int64_t  s02, const  int64_t  s03,
@@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
4545}
4646
4747//  Template dispatch function for quantized set_rows
48- template <typename  block_type, int  qk, void  (*quantize_func)(const  float *, block_type*)>
48+ template <typename  idx_t ,  typename   block_type, int  qk, void  (*quantize_func)(const  float *, block_type*)>
4949static  void  set_rows_cuda_quant (
50-         const  float  * src0_d, const  int64_t  * src1_d, block_type * dst_d,
50+         const  float  * src0_d, const  idx_t  * src1_d, block_type * dst_d,
5151        const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
5252        const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
5353        const  size_t  nb01, const  size_t  nb02, const  size_t  nb03,
@@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
6464    const  int64_t  s01 = nb01/sizeof (float );
6565    const  int64_t  s02 = nb02/sizeof (float );
6666    const  int64_t  s03 = nb03/sizeof (float );
67-     const  int64_t  s10 = nb10/sizeof (int64_t );
68-     const  int64_t  s11 = nb11/sizeof (int64_t );
69-     const  int64_t  s12 = nb12/sizeof (int64_t );
67+     const  int64_t  s10 = nb10/sizeof (idx_t );
68+     const  int64_t  s11 = nb11/sizeof (idx_t );
69+     const  int64_t  s12 = nb12/sizeof (idx_t );
7070    const  int64_t  s1  = nb1;
7171    const  int64_t  s2  = nb2;
7272    const  int64_t  s3  = nb3;
7373
7474    if  (ne_total > 0 ) {
75-         k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0 , stream>>> (
75+         k_set_rows_quant<idx_t ,  block_type, qk, quantize_func><<<grid_size, block_size, 0 , stream>>> (
7676            src0_d, src1_d, dst_d,
7777            ne00, ne01, ne02, ne03,
7878            ne10, ne11, ne12, ne13,
@@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
8282    }
8383}
8484
85- template <typename  src_t , typename  dst_t >
85+ template <typename  src_t , typename  idx_t ,  typename   dst_t >
8686static  __global__  void  k_set_rows (
87-         const  src_t  * __restrict__  src0, const  int64_t  * __restrict__  src1, dst_t  * __restrict__  dst,
87+         const  src_t  * __restrict__  src0, const  idx_t  * __restrict__  src1, dst_t  * __restrict__  dst,
8888        const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
8989        const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
9090        const  int64_t  s01, const  int64_t  s02, const  int64_t  s03,
@@ -118,9 +118,9 @@ static __global__ void k_set_rows(
118118    GGML_UNUSED (ne13);
119119}
120120
121- template <typename  src_t , typename  dst_t >
121+ template <typename  src_t , typename  idx_t ,  typename   dst_t >
122122static  void  set_rows_cuda (
123-         const  src_t  * src0_d, const  int64_t  * src1_d, dst_t  * dst_d,
123+         const  src_t  * src0_d, const  idx_t  * src1_d, dst_t  * dst_d,
124124        const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
125125        const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
126126        const  size_t  nb01, const  size_t  nb02, const  size_t  nb03,
@@ -137,9 +137,9 @@ static void set_rows_cuda(
137137    const  int64_t  s01 = nb01/sizeof (src_t );
138138    const  int64_t  s02 = nb02/sizeof (src_t );
139139    const  int64_t  s03 = nb03/sizeof (src_t );
140-     const  int64_t  s10 = nb10/sizeof (int64_t );
141-     const  int64_t  s11 = nb11/sizeof (int64_t );
142-     const  int64_t  s12 = nb12/sizeof (int64_t );
140+     const  int64_t  s10 = nb10/sizeof (idx_t );
141+     const  int64_t  s11 = nb11/sizeof (idx_t );
142+     const  int64_t  s12 = nb12/sizeof (idx_t );
143143    const  int64_t  s1  = nb1/sizeof (dst_t );
144144    const  int64_t  s2  = nb2/sizeof (dst_t );
145145    const  int64_t  s3  = nb3/sizeof (dst_t );
@@ -155,23 +155,16 @@ static void set_rows_cuda(
155155    }
156156}
157157
158- 
159- void  ggml_cuda_op_set_rows (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160-     const  ggml_tensor * src0 = dst->src [0 ];
161-     const  ggml_tensor * src1 = dst->src [1 ];
162- 
163-     GGML_ASSERT (src0->type  == GGML_TYPE_F32);
164-     GGML_ASSERT (src1->type  == GGML_TYPE_I64);
158+ template <typename  src_t , typename  idx_t >
159+ static  void  set_rows_cuda (ggml_backend_cuda_context & ctx, const  ggml_tensor * src0, const  ggml_tensor * src1, ggml_tensor * dst) {
160+     const  src_t  * src0_d = (const  src_t  *)src0->data ;
161+     const  idx_t  * src1_d = (const  idx_t  *)src1->data ;
165162
166163    GGML_TENSOR_BINARY_OP_LOCALS
167164
168-     const  float  * src0_d   = (const  float  *)src0->data ;
169-     const  int64_t  * src1_d = (const  int64_t  *)src1->data ;
170- 
171165    cudaStream_t stream = ctx.stream ();
172166
173167
174- 
175168    if  (dst->type  == GGML_TYPE_F32) {
176169        set_rows_cuda (
177170            src0_d, src1_d, (float *)dst->data ,
@@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
203196            stream
204197        );
205198    } else  if  (dst->type  == GGML_TYPE_Q4_0) {
206-         set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
199+         set_rows_cuda_quant<idx_t ,  block_q4_0, QK4_0, quantize_f32_q4_0_block>(
207200            src0_d, src1_d, (block_q4_0*)dst->data ,
208201            ne00, ne01, ne02, ne03,
209202            ne10, ne11, ne12, ne13,
@@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
213206            stream
214207        );
215208    } else  if  (dst->type  == GGML_TYPE_Q4_1) {
216-         set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
209+         set_rows_cuda_quant<idx_t ,  block_q4_1, QK4_1, quantize_f32_q4_1_block>(
217210            src0_d, src1_d, (block_q4_1*)dst->data ,
218211            ne00, ne01, ne02, ne03,
219212            ne10, ne11, ne12, ne13,
@@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
223216            stream
224217        );
225218    } else  if  (dst->type  == GGML_TYPE_Q5_0) {
226-         set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
219+         set_rows_cuda_quant<idx_t ,  block_q5_0, QK5_0, quantize_f32_q5_0_block>(
227220            src0_d, src1_d, (block_q5_0*)dst->data ,
228221            ne00, ne01, ne02, ne03,
229222            ne10, ne11, ne12, ne13,
@@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
233226            stream
234227        );
235228    } else  if  (dst->type  == GGML_TYPE_Q5_1) {
236-         set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
229+         set_rows_cuda_quant<idx_t ,  block_q5_1, QK5_1, quantize_f32_q5_1_block>(
237230            src0_d, src1_d, (block_q5_1*)dst->data ,
238231            ne00, ne01, ne02, ne03,
239232            ne10, ne11, ne12, ne13,
@@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
243236            stream
244237        );
245238    } else  if  (dst->type  == GGML_TYPE_Q8_0) {
246-         set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
239+         set_rows_cuda_quant<idx_t ,  block_q8_0, QK8_0, quantize_f32_q8_0_block>(
247240            src0_d, src1_d, (block_q8_0*)dst->data ,
248241            ne00, ne01, ne02, ne03,
249242            ne10, ne11, ne12, ne13,
@@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
253246            stream
254247        );
255248    } else  if  (dst->type  == GGML_TYPE_IQ4_NL) {
256-         set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
249+         set_rows_cuda_quant<idx_t ,  block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
257250            src0_d, src1_d, (block_iq4_nl*)dst->data ,
258251            ne00, ne01, ne02, ne03,
259252            ne10, ne11, ne12, ne13,
@@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
266259        GGML_ABORT (" unsupported type %s" ggml_type_name (dst->type ));
267260    }
268261}
262+ 
263+ 
// Entry point for the SET_ROWS op.
//
// Writes rows taken from src0 (must be F32) into dst at the row positions
// named by the index tensor src1. The index tensor may hold either 64-bit
// or 32-bit integers, so we dispatch on src1->type and instantiate the
// templated implementation with the matching idx_t — the element width is
// needed so the kernel strides through the index buffer correctly.
// dst's type (plain or quantized) is handled inside set_rows_cuda.
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

    const bool idx_is_i64 = (src1->type == GGML_TYPE_I64);
    if (idx_is_i64) {
        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
    } else {
        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
    }
}
0 commit comments