Skip to content

Commit af23220

Browse files
committed
Use global async variable to decide path in sycl_ext_[malloc_device|free]
1 parent bc430c6 commit af23220

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3049,34 +3049,36 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
30493049
}
30503050

30513051
// Helper functions to unify device memory allocation for both async and sync paths
3052-
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size, bool use_async) {
3052+
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3053+
bool use_async = g_ggml_sycl_use_async_mem_op;
30533054
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
30543055
if (use_async) {
30553056
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
30563057
}
30573058
#else
3058-
// If async allocation extension is not available, we should have never passed use_async=true
3059+
// If async allocation extension is not available, use_async should always be false.
30593060
GGML_ASSERT(!use_async);
30603061
#endif
30613062
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
30623063
}
30633064

3064-
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr, bool use_async) {
3065+
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3066+
bool use_async = g_ggml_sycl_use_async_mem_op;
30653067
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
30663068
if (use_async) {
30673069
syclex::async_free(*stream, ptr);
30683070
return;
30693071
}
30703072
#else
3071-
// If async allocation extension is not available, we should have never passed use_async=true
3073+
// If async allocation extension is not available, use_async should always be false.
30723074
GGML_ASSERT(!use_async);
30733075
#endif
30743076
sycl::free(ptr, *stream);
30753077
}
30763078

30773079
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
30783080
dpct::queue_ptr stream) {
3079-
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
3081+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
30803082

30813083
sycl::event copy_event;
30823084
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3105,7 +3107,7 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
31053107
if (!g_ggml_sycl_use_async_mem_op) {
31063108
reorder_event.wait_and_throw();
31073109
}
3108-
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
3110+
sycl_ext_free(stream, tmp_buf);
31093111
}
31103112

31113113
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3114,7 +3116,7 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
31143116

31153117
const int nblocks = size / sizeof(block_q4_K);
31163118

3117-
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
3119+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
31183120

31193121
sycl::event copy_event;
31203122
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3143,7 +3145,7 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
31433145
if (!g_ggml_sycl_use_async_mem_op) {
31443146
reorder_event.wait_and_throw();
31453147
}
3146-
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
3148+
sycl_ext_free(stream, tmp_buf);
31473149
}
31483150

31493151
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3152,7 +3154,7 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
31523154

31533155
const int nblocks = size / sizeof(block_q6_K);
31543156

3155-
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
3157+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
31563158

31573159
sycl::event copy_event;
31583160
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3191,7 +3193,7 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
31913193
if (!g_ggml_sycl_use_async_mem_op) {
31923194
reorder_event.wait_and_throw();
31933195
}
3194-
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
3196+
sycl_ext_free(stream, tmp_buf);
31953197
}
31963198

31973199
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {

0 commit comments

Comments (0)