@@ -155,220 +155,122 @@ static void set_rows_cuda(
155155    }
156156}
157157
// Dispatch SET_ROWS to the typed kernel launcher matching dst->type.
//
// src0 holds the source rows, interpreted as src_t (the public entry point
// instantiates this with float). src1 holds the destination row indices,
// interpreted as idx_t (int32_t or int64_t, selected by the caller from
// src1->type). dst receives the scattered rows; for block-quantized
// destination types the rows are quantized on the fly by the per-type
// quantize_* functor passed to set_rows_cuda_quant.
//
// All kernel launches are asynchronous on ctx.stream(); no synchronization
// is performed here.
template <typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const src_t * src0_d = (const src_t *)src0->data;
    const idx_t * src1_d = (const idx_t *)src1->data;

    // Brings ne0x/ne1x (element counts) and nb0x/nb1x/nbx (byte strides)
    // for src0, src1 and dst into scope.
    GGML_TENSOR_BINARY_OP_LOCALS

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_BF16) {
        set_rows_cuda(
            src0_d, src1_d, (nv_bfloat16 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0 *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        // BUG FIX: the source text was missing the comma between the printf-style
        // format string and its argument, which is malformed.
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}

// Public entry point for the CUDA SET_ROWS op.
//
// Validates the operand types (rows must be f32; indices must be i64 or i32),
// then forwards to the typed dispatcher, picking the index width to match
// src1->type.
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

    // Only the index tensor's width varies; the row payload is always f32 here.
    if (src1->type == GGML_TYPE_I64) {
        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
    } else {
        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
    }
}
0 commit comments