Skip to content

Commit aa76826

Browse files
authored
deduplicate cuda/sycl and test-fix [no-ci]
1 parent 9899182 commit aa76826

File tree

3 files changed

+161
-310
lines changed

3 files changed

+161
-310
lines changed

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 100 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -155,220 +155,122 @@ static void set_rows_cuda(
155155
}
156156
}
157157

158-
159-
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160-
const ggml_tensor * src0 = dst->src[0];
161-
const ggml_tensor * src1 = dst->src[1];
162-
163-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
164-
GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
158+
template<typename src_t, typename idx_t>
159+
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
160+
const src_t * src0_d = (const src_t *)src0->data;
161+
const idx_t * src1_d = (const idx_t *)src1->data;
165162

166163
GGML_TENSOR_BINARY_OP_LOCALS
167164

168-
const float * src0_d = (const float *)src0->data;
169-
170165
cudaStream_t stream = ctx.stream();
171166

172167

173168
if (dst->type == GGML_TYPE_F32) {
174-
if (src1->type == GGML_TYPE_I64) {
175-
set_rows_cuda(
176-
src0_d, (const int64_t *)src1->data, (float*)dst->data,
177-
ne00, ne01, ne02, ne03,
178-
ne10, ne11, ne12, ne13,
179-
nb01, nb02, nb03,
180-
nb10, nb11, nb12,
181-
nb1, nb2, nb3,
182-
stream
183-
);
184-
} else {
185-
set_rows_cuda(
186-
src0_d, (const int32_t *)src1->data, (float*)dst->data,
187-
ne00, ne01, ne02, ne03,
188-
ne10, ne11, ne12, ne13,
189-
nb01, nb02, nb03,
190-
nb10, nb11, nb12,
191-
nb1, nb2, nb3,
192-
stream
193-
);
194-
}
169+
set_rows_cuda(
170+
src0_d, src1_d, (float*)dst->data,
171+
ne00, ne01, ne02, ne03,
172+
ne10, ne11, ne12, ne13,
173+
nb01, nb02, nb03,
174+
nb10, nb11, nb12,
175+
nb1, nb2, nb3,
176+
stream
177+
);
195178
} else if (dst->type == GGML_TYPE_F16) {
196-
if (src1->type == GGML_TYPE_I64) {
197-
set_rows_cuda(
198-
src0_d, (const int64_t *)src1->data, (half*)dst->data,
199-
ne00, ne01, ne02, ne03,
200-
ne10, ne11, ne12, ne13,
201-
nb01, nb02, nb03,
202-
nb10, nb11, nb12,
203-
nb1, nb2, nb3,
204-
stream
205-
);
206-
} else {
207-
set_rows_cuda(
208-
src0_d, (const int32_t *)src1->data, (half*)dst->data,
209-
ne00, ne01, ne02, ne03,
210-
ne10, ne11, ne12, ne13,
211-
nb01, nb02, nb03,
212-
nb10, nb11, nb12,
213-
nb1, nb2, nb3,
214-
stream
215-
);
216-
}
179+
set_rows_cuda(
180+
src0_d, src1_d, (half*)dst->data,
181+
ne00, ne01, ne02, ne03,
182+
ne10, ne11, ne12, ne13,
183+
nb01, nb02, nb03,
184+
nb10, nb11, nb12,
185+
nb1, nb2, nb3,
186+
stream
187+
);
217188
} else if (dst->type == GGML_TYPE_BF16) {
218-
if (src1->type == GGML_TYPE_I64) {
219-
set_rows_cuda(
220-
src0_d, (const int64_t *)src1->data, (nv_bfloat16*)dst->data,
221-
ne00, ne01, ne02, ne03,
222-
ne10, ne11, ne12, ne13,
223-
nb01, nb02, nb03,
224-
nb10, nb11, nb12,
225-
nb1, nb2, nb3,
226-
stream
227-
);
228-
} else {
229-
set_rows_cuda(
230-
src0_d, (const int32_t *)src1->data, (nv_bfloat16*)dst->data,
231-
ne00, ne01, ne02, ne03,
232-
ne10, ne11, ne12, ne13,
233-
nb01, nb02, nb03,
234-
nb10, nb11, nb12,
235-
nb1, nb2, nb3,
236-
stream
237-
);
238-
}
189+
set_rows_cuda(
190+
src0_d, src1_d, (nv_bfloat16*)dst->data,
191+
ne00, ne01, ne02, ne03,
192+
ne10, ne11, ne12, ne13,
193+
nb01, nb02, nb03,
194+
nb10, nb11, nb12,
195+
nb1, nb2, nb3,
196+
stream
197+
);
239198
} else if (dst->type == GGML_TYPE_Q4_0) {
240-
if (src1->type == GGML_TYPE_I64) {
241-
set_rows_cuda_quant<int64_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
242-
src0_d, (const int64_t *)src1->data, (block_q4_0*)dst->data,
243-
ne00, ne01, ne02, ne03,
244-
ne10, ne11, ne12, ne13,
245-
nb01, nb02, nb03,
246-
nb10, nb11, nb12,
247-
nb1, nb2, nb3,
248-
stream
249-
);
250-
} else {
251-
set_rows_cuda_quant<int32_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
252-
src0_d, (const int32_t *)src1->data, (block_q4_0*)dst->data,
253-
ne00, ne01, ne02, ne03,
254-
ne10, ne11, ne12, ne13,
255-
nb01, nb02, nb03,
256-
nb10, nb11, nb12,
257-
nb1, nb2, nb3,
258-
stream
259-
);
260-
}
199+
set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
200+
src0_d, src1_d, (block_q4_0*)dst->data,
201+
ne00, ne01, ne02, ne03,
202+
ne10, ne11, ne12, ne13,
203+
nb01, nb02, nb03,
204+
nb10, nb11, nb12,
205+
nb1, nb2, nb3,
206+
stream
207+
);
261208
} else if (dst->type == GGML_TYPE_Q4_1) {
262-
if (src1->type == GGML_TYPE_I64) {
263-
set_rows_cuda_quant<int64_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
264-
src0_d, (const int64_t *)src1->data, (block_q4_1*)dst->data,
265-
ne00, ne01, ne02, ne03,
266-
ne10, ne11, ne12, ne13,
267-
nb01, nb02, nb03,
268-
nb10, nb11, nb12,
269-
nb1, nb2, nb3,
270-
stream
271-
);
272-
} else {
273-
set_rows_cuda_quant<int32_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
274-
src0_d, (const int32_t *)src1->data, (block_q4_1*)dst->data,
275-
ne00, ne01, ne02, ne03,
276-
ne10, ne11, ne12, ne13,
277-
nb01, nb02, nb03,
278-
nb10, nb11, nb12,
279-
nb1, nb2, nb3,
280-
stream
281-
);
282-
}
209+
set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
210+
src0_d, src1_d, (block_q4_1*)dst->data,
211+
ne00, ne01, ne02, ne03,
212+
ne10, ne11, ne12, ne13,
213+
nb01, nb02, nb03,
214+
nb10, nb11, nb12,
215+
nb1, nb2, nb3,
216+
stream
217+
);
283218
} else if (dst->type == GGML_TYPE_Q5_0) {
284-
if (src1->type == GGML_TYPE_I64) {
285-
set_rows_cuda_quant<int64_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
286-
src0_d, (const int64_t *)src1->data, (block_q5_0*)dst->data,
287-
ne00, ne01, ne02, ne03,
288-
ne10, ne11, ne12, ne13,
289-
nb01, nb02, nb03,
290-
nb10, nb11, nb12,
291-
nb1, nb2, nb3,
292-
stream
293-
);
294-
} else {
295-
set_rows_cuda_quant<int32_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
296-
src0_d, (const int32_t *)src1->data, (block_q5_0*)dst->data,
297-
ne00, ne01, ne02, ne03,
298-
ne10, ne11, ne12, ne13,
299-
nb01, nb02, nb03,
300-
nb10, nb11, nb12,
301-
nb1, nb2, nb3,
302-
stream
303-
);
304-
}
219+
set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
220+
src0_d, src1_d, (block_q5_0*)dst->data,
221+
ne00, ne01, ne02, ne03,
222+
ne10, ne11, ne12, ne13,
223+
nb01, nb02, nb03,
224+
nb10, nb11, nb12,
225+
nb1, nb2, nb3,
226+
stream
227+
);
305228
} else if (dst->type == GGML_TYPE_Q5_1) {
306-
if (src1->type == GGML_TYPE_I64) {
307-
set_rows_cuda_quant<int64_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
308-
src0_d, (const int64_t *)src1->data, (block_q5_1*)dst->data,
309-
ne00, ne01, ne02, ne03,
310-
ne10, ne11, ne12, ne13,
311-
nb01, nb02, nb03,
312-
nb10, nb11, nb12,
313-
nb1, nb2, nb3,
314-
stream
315-
);
316-
} else {
317-
set_rows_cuda_quant<int32_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
318-
src0_d, (const int32_t *)src1->data, (block_q5_1*)dst->data,
319-
ne00, ne01, ne02, ne03,
320-
ne10, ne11, ne12, ne13,
321-
nb01, nb02, nb03,
322-
nb10, nb11, nb12,
323-
nb1, nb2, nb3,
324-
stream
325-
);
326-
}
229+
set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
230+
src0_d, src1_d, (block_q5_1*)dst->data,
231+
ne00, ne01, ne02, ne03,
232+
ne10, ne11, ne12, ne13,
233+
nb01, nb02, nb03,
234+
nb10, nb11, nb12,
235+
nb1, nb2, nb3,
236+
stream
237+
);
327238
} else if (dst->type == GGML_TYPE_Q8_0) {
328-
if (src1->type == GGML_TYPE_I64) {
329-
set_rows_cuda_quant<int64_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
330-
src0_d, (const int64_t *)src1->data, (block_q8_0*)dst->data,
331-
ne00, ne01, ne02, ne03,
332-
ne10, ne11, ne12, ne13,
333-
nb01, nb02, nb03,
334-
nb10, nb11, nb12,
335-
nb1, nb2, nb3,
336-
stream
337-
);
338-
} else {
339-
set_rows_cuda_quant<int32_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
340-
src0_d, (const int32_t *)src1->data, (block_q8_0*)dst->data,
341-
ne00, ne01, ne02, ne03,
342-
ne10, ne11, ne12, ne13,
343-
nb01, nb02, nb03,
344-
nb10, nb11, nb12,
345-
nb1, nb2, nb3,
346-
stream
347-
);
348-
}
239+
set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
240+
src0_d, src1_d, (block_q8_0*)dst->data,
241+
ne00, ne01, ne02, ne03,
242+
ne10, ne11, ne12, ne13,
243+
nb01, nb02, nb03,
244+
nb10, nb11, nb12,
245+
nb1, nb2, nb3,
246+
stream
247+
);
349248
} else if (dst->type == GGML_TYPE_IQ4_NL) {
350-
if (src1->type == GGML_TYPE_I64) {
351-
set_rows_cuda_quant<int64_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
352-
src0_d, (const int64_t *)src1->data, (block_iq4_nl*)dst->data,
353-
ne00, ne01, ne02, ne03,
354-
ne10, ne11, ne12, ne13,
355-
nb01, nb02, nb03,
356-
nb10, nb11, nb12,
357-
nb1, nb2, nb3,
358-
stream
359-
);
360-
} else {
361-
set_rows_cuda_quant<int32_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
362-
src0_d, (const int32_t *)src1->data, (block_iq4_nl*)dst->data,
363-
ne00, ne01, ne02, ne03,
364-
ne10, ne11, ne12, ne13,
365-
nb01, nb02, nb03,
366-
nb10, nb11, nb12,
367-
nb1, nb2, nb3,
368-
stream
369-
);
370-
}
249+
set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
250+
src0_d, src1_d, (block_iq4_nl*)dst->data,
251+
ne00, ne01, ne02, ne03,
252+
ne10, ne11, ne12, ne13,
253+
nb01, nb02, nb03,
254+
nb10, nb11, nb12,
255+
nb1, nb2, nb3,
256+
stream
257+
);
371258
} else {
372259
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
373260
}
374261
}
262+
263+
264+
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
265+
const ggml_tensor * src0 = dst->src[0];
266+
const ggml_tensor * src1 = dst->src[1];
267+
268+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
269+
GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
270+
271+
if (src1->type == GGML_TYPE_I64) {
272+
set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
273+
} else {
274+
set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
275+
}
276+
}

0 commit comments

Comments
 (0)