Skip to content

Commit bc430c6

Browse files
committed
Address reviewer feedback
1 parent f4e1782 commit bc430c6

File tree

1 file changed

+48
-59
lines changed

1 file changed

+48
-59
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 48 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ int g_ggml_sycl_disable_optimize = 0;
5757
int g_ggml_sycl_disable_graph = 0;
5858
int g_ggml_sycl_disable_dnn = 0;
5959
int g_ggml_sycl_prioritize_dmmv = 0;
60-
int g_ggml_sycl_disable_async_mem_alloc = 0;
60+
int g_ggml_sycl_use_async_mem_op = 0;
6161

6262
static ggml_sycl_device_info ggml_sycl_init() {
6363
ggml_sycl_device_info info = {};
@@ -244,17 +244,16 @@ static void ggml_check_sycl() try {
244244
// Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
245245
// properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
246246
// other places.
247-
g_ggml_sycl_disable_async_mem_alloc =
248247
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
249-
g_ggml_sycl_disable_graph;
250-
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count() && !g_ggml_sycl_disable_async_mem_alloc;
251-
++i) {
252-
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
253-
g_ggml_sycl_disable_async_mem_alloc = 1;
248+
g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
249+
if (g_ggml_sycl_use_async_mem_op) {
250+
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
251+
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
252+
g_ggml_sycl_use_async_mem_op = 0;
253+
break;
254+
}
254255
}
255256
}
256-
#else
257-
1;
258257
#endif
259258
if (CHECK_TRY_ERROR(g_all_sycl_device_count =
260259
dpct::dev_mgr::instance().device_count()) != 0) {
@@ -3050,22 +3049,19 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
30503049
}
30513050

30523051
// Helper functions to unify device memory allocation for both async and sync paths
3053-
static inline void * sycl_malloc_opt_async(dpct::queue_ptr stream,
3054-
sycl::usm::alloc alloc_type,
3055-
size_t size,
3056-
bool use_async) {
3052+
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size, bool use_async) {
30573053
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
30583054
if (use_async) {
3059-
return syclex::async_malloc(*stream, alloc_type, size);
3055+
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
30603056
}
30613057
#else
30623058
// If the async allocation extension is not available, we should never have passed use_async=true
30633059
GGML_ASSERT(!use_async);
30643060
#endif
3065-
return sycl::malloc(size, *stream, alloc_type);
3061+
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
30663062
}
30673063

3068-
static inline void sycl_free_opt_async(dpct::queue_ptr stream, void * ptr, bool use_async) {
3064+
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr, bool use_async) {
30693065
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
30703066
if (use_async) {
30713067
syclex::async_free(*stream, ptr);
@@ -3080,13 +3076,11 @@ static inline void sycl_free_opt_async(dpct::queue_ptr stream, void * ptr, bool
30803076

30813077
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
30823078
dpct::queue_ptr stream) {
3083-
const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
3084-
uint8_t * tmp_buf =
3085-
static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
3079+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
30863080

30873081
sycl::event copy_event;
30883082
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3089-
if (!use_async) {
3083+
if (!g_ggml_sycl_use_async_mem_op) {
30903084
copy_event.wait();
30913085
}
30923086

@@ -3108,10 +3102,10 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
31083102
}
31093103
*(d_ptr + ib) = x[ib].d;
31103104
});
3111-
if (!use_async) {
3105+
if (!g_ggml_sycl_use_async_mem_op) {
31123106
reorder_event.wait_and_throw();
31133107
}
3114-
sycl_free_opt_async(stream, tmp_buf, use_async);
3108+
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
31153109
}
31163110

31173111
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3120,13 +3114,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
31203114

31213115
const int nblocks = size / sizeof(block_q4_K);
31223116

3123-
const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
3124-
uint8_t * tmp_buf =
3125-
static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
3117+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
31263118

31273119
sycl::event copy_event;
31283120
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3129-
if (!use_async) {
3121+
if (!g_ggml_sycl_use_async_mem_op) {
31303122
copy_event.wait();
31313123
}
31323124

@@ -3148,10 +3140,10 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
31483140

31493141
dm_ptr[ib] = x[ib].dm;
31503142
});
3151-
if (!use_async) {
3143+
if (!g_ggml_sycl_use_async_mem_op) {
31523144
reorder_event.wait_and_throw();
31533145
}
3154-
sycl_free_opt_async(stream, tmp_buf, use_async);
3146+
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
31553147
}
31563148

31573149
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3160,13 +3152,11 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
31603152

31613153
const int nblocks = size / sizeof(block_q6_K);
31623154

3163-
const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
3164-
uint8_t * tmp_buf =
3165-
static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
3155+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
31663156

31673157
sycl::event copy_event;
31683158
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3169-
if (!use_async) {
3159+
if (!g_ggml_sycl_use_async_mem_op) {
31703160
copy_event.wait();
31713161
}
31723162

@@ -3175,34 +3165,33 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
31753165
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
31763166
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
31773167

3178-
auto reorder_event = stream->parallel_for(nblocks,
3179-
[=](auto i) {
3180-
const block_q6_K * x = (const block_q6_K *) tmp_buf;
3181-
const int ib = i;
3182-
3183-
const uint8_t * ql = x[ib].ql;
3184-
const uint8_t * qh = x[ib].qh;
3185-
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3186-
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3187-
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3188-
3189-
for (int j = 0; j < QK_K / 2; ++j) {
3190-
base_ql_ptr[j] = ql[j];
3191-
}
3192-
for (int j = 0; j < QK_K / 4; ++j) {
3193-
base_qh_ptr[j] = qh[j];
3194-
}
3195-
3196-
for (int j = 0; j < QK_K / 16; ++j) {
3197-
base_scales_ptr[j] = x[ib].scales[j];
3198-
}
3199-
3200-
dm_ptr[ib] = x[ib].d;
3201-
});
3202-
if (!use_async) {
3168+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3169+
const block_q6_K * x = (const block_q6_K *) tmp_buf;
3170+
const int ib = i;
3171+
3172+
const uint8_t * ql = x[ib].ql;
3173+
const uint8_t * qh = x[ib].qh;
3174+
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3175+
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3176+
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3177+
3178+
for (int j = 0; j < QK_K / 2; ++j) {
3179+
base_ql_ptr[j] = ql[j];
3180+
}
3181+
for (int j = 0; j < QK_K / 4; ++j) {
3182+
base_qh_ptr[j] = qh[j];
3183+
}
3184+
3185+
for (int j = 0; j < QK_K / 16; ++j) {
3186+
base_scales_ptr[j] = x[ib].scales[j];
3187+
}
3188+
3189+
dm_ptr[ib] = x[ib].d;
3190+
});
3191+
if (!g_ggml_sycl_use_async_mem_op) {
32033192
reorder_event.wait_and_throw();
32043193
}
3205-
sycl_free_opt_async(stream, tmp_buf, use_async);
3194+
sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
32063195
}
32073196

32083197
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -4116,7 +4105,7 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
41164105
// We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
41174106
// as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
41184107
// in reordering.
4119-
if (g_ggml_sycl_disable_async_mem_alloc) {
4108+
if (!g_ggml_sycl_use_async_mem_op) {
41204109
GGML_LOG_INFO(
41214110
"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
41224111
"oneAPI async memory allocation extension "

0 commit comments

Comments
 (0)