Skip to content

Commit 4ce5fb5

Browse files
MengAiDevMengAiDev
authored andcommitted
fix(ggml-sycl): resolve Windows null pointer bug and improve memory operations
- Add null pointer check for tensor data on Windows platform - Implement proper memset operation in buffer clear function - Add get_tensor function and include null pointer check - Improve error handling and logging in SYCL operations
1 parent 9c42e0c commit 4ce5fb5

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,10 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
399399
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
400400
free(host_buf);
401401
#else
402+
// 修复Windows平台空指针漏洞:添加与Linux平台一致的空指针检查
403+
if (tensor->data == nullptr) {
404+
GGML_ABORT("Error: Tensor data pointer is null.\n");
405+
}
402406
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
403407
#endif
404408
}
@@ -430,6 +434,10 @@ catch (sycl::exception const &exc) {
430434
std::exit(1);
431435
}
432436

437+
<< ", line:" << __LINE__ << std::endl;
438+
std::exit(1);
439+
}
440+
433441
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
434442
const void *ptr_src, size_t size) {
435443
char *host_buf = (char *)malloc(size);
@@ -511,10 +519,10 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
511519
queue_ptr stream = ctx->stream;
512520
SYCL_CHECK(
513521
CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
514-
515-
SYCL_CHECK(CHECK_TRY_ERROR((*stream)
516-
.memset(ctx->dev_ptr, value, buffer->size)
517-
.wait()));
522+
523+
// FIX: 添加memset操作清除缓冲区
524+
SYCL_CHECK(CHECK_TRY_ERROR(
525+
(*stream).memset(ctx->dev_ptr, value, buffer->size).wait()));
518526
}
519527
catch (sycl::exception const &exc) {
520528
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -525,6 +533,8 @@ catch (sycl::exception const &exc) {
525533
static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
526534
size_t offset, size_t size) {
527535
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
536+
537+
528538
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
529539
GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
530540
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
@@ -3879,6 +3889,34 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
38793889
void *data, size_t offset,
38803890
size_t size) try {
38813891
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
3892+
}
3893+
3894+
static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
3895+
const ggml_tensor *tensor,
3896+
void *data, size_t offset,
3897+
size_t size) try {
3898+
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
3899+
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
3900+
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
3901+
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
3902+
3903+
ggml_sycl_set_device(ctx->device);
3904+
auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
3905+
3906+
// 修复空指针漏洞:添加与set_tensor一致的空指针检查
3907+
if (tensor->data == nullptr) {
3908+
GGML_ABORT("Error: Tensor data pointer is null in get_tensor.\n");
3909+
}
3910+
3911+
SYCL_CHECK(CHECK_TRY_ERROR(
3912+
stream.memcpy(data, (const char *)tensor->data + offset, size)
3913+
.wait()));
3914+
}
3915+
catch (sycl::exception const &exc) {
3916+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3917+
<< ", line:" << __LINE__ << std::endl;
3918+
std::exit(1);
3919+
38823920
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
38833921
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
38843922
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;

0 commit comments

Comments
 (0)