Skip to content

Commit ffb2cc8

Browse files
author
Alberto Cabrera
committed
Refactored src1 copy logic in op_mul_mat
1 parent eda44a4 commit ffb2cc8

File tree

1 file changed

+22
-43
lines changed

1 file changed

+22
-43
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,11 +2394,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
23942394
// here an event is recorded that signals that the main device has finished calculating the input data
23952395
if (split && used_devices > 1) {
23962396
ggml_sycl_set_device(ctx.device);
2397-
/*
2398-
DPCT1024:91: The original code returned the error code that was further
2399-
consumed by the program logic. This original code was replaced with 0.
2400-
You may need to rewrite the program logic consuming the error code.
2401-
*/
24022397
SYCL_CHECK(CHECK_TRY_ERROR(
24032398
*src0_extra->events[ctx.device][0] =
24042399
ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2422,11 +2417,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24222417

24232418
// wait for main GPU data if necessary
24242419
if (split && (i != ctx.device || is != 0)) {
2425-
/*
2426-
DPCT1009:163: SYCL uses exceptions to report errors and does not
2427-
use the error codes. The original code was commented out and a
2428-
warning string was inserted. You need to rewrite this code.
2429-
*/
24302420
SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
24312421
{*src0_extra->events[ctx.device][0]})));
24322422
}
@@ -2452,38 +2442,39 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24522442
// copy src0, src1 to device if necessary
24532443
if (src1_is_contiguous) {
24542444
if (i != ctx.device) {
2455-
if constexpr(quantize_enabled) {
2445+
if constexpr (quantize_enabled) {
24562446
char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
2457-
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
2458-
src1_ddq_i, src1_ddq_i_source,
2459-
src1_ncols * src1_padded_col_size * q8_1_ts /
2460-
q8_1_bs).wait()));
2447+
SYCL_CHECK(
2448+
CHECK_TRY_ERROR(stream
2449+
->memcpy(src1_ddq_i, src1_ddq_i_source,
2450+
src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
2451+
.wait()));
24612452
} else {
2462-
24632453
float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
2464-
src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
2454+
src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
24652455

2466-
SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
2467-
src1_ddf_i, src1_ddf_i_source,
2468-
src1_ncols * ne10 * sizeof(float))));
2456+
SYCL_CHECK(
2457+
CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
2458+
src1_ncols * ne10 * sizeof(float))));
24692459
}
24702460
}
2471-
} else if (src1_on_device && !src1_is_contiguous) {
2472-
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
2473-
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
24742461
} else {
2475-
GGML_ABORT("fatal error");
2476-
}
2462+
if (src1_on_device) {
2463+
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
2464+
src1_col_0 + src1_ncols, stream));
2465+
} else {
2466+
GGML_ABORT("src1 is non-contiguous and not on device");
2467+
}
24772468

2478-
if constexpr(quantize_enabled) {
2479-
if (!src1_is_contiguous) {
2469+
if constexpr (quantize_enabled) {
24802470
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
24812471
/*num_src=*/2, " : converting src1 to Q8_1");
24822472
try {
2483-
quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
2484-
} catch (sycl::exception const &exc) {
2485-
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
2486-
<< ", line:" << __LINE__ << std::endl;
2473+
quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
2474+
src1_padded_col_size, stream);
2475+
} catch (const sycl::exception & exc) {
2476+
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
2477+
<< "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
24872478
std::exit(1);
24882479
}
24892480
}
@@ -2498,12 +2489,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24982489
// do the computation
24992490
SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
25002491
dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
2501-
/*
2502-
DPCT1010:93: SYCL uses exceptions to report errors and does not
2503-
use the error codes. The call was replaced with 0. You need to
2504-
rewrite this code.
2505-
*/
2506-
SYCL_CHECK(0);
25072492

25082493
// copy dst to host or other device if necessary
25092494
if (!dst_on_device) {
@@ -2534,12 +2519,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25342519

25352520
// add event for the main device to wait on until other device is done
25362521
if (split && (i != ctx.device || is != 0)) {
2537-
/*
2538-
DPCT1024:94: The original code returned the error code that
2539-
was further consumed by the program logic. This original
2540-
code was replaced with 0. You may need to rewrite the
2541-
program logic consuming the error code.
2542-
*/
25432522
SYCL_CHECK(CHECK_TRY_ERROR(
25442523
*src0_extra->events[i][is] =
25452524
stream->ext_oneapi_submit_barrier()));

0 commit comments

Comments
 (0)