@@ -57,7 +57,7 @@ int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
 int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
-int g_ggml_sycl_disable_async_mem_alloc = 0;
+int g_ggml_sycl_use_async_mem_op = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -244,17 +244,16 @@ static void ggml_check_sycl() try {
     // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
     // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
     // other places.
-    g_ggml_sycl_disable_async_mem_alloc =
 #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
-        g_ggml_sycl_disable_graph;
-    for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count() && !g_ggml_sycl_disable_async_mem_alloc;
-         ++i) {
-        if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
-            g_ggml_sycl_disable_async_mem_alloc = 1;
+    g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
+    if (g_ggml_sycl_use_async_mem_op) {
+        for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
+            if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
+                g_ggml_sycl_use_async_mem_op = 0;
+                break;
+            }
         }
     }
-#else
-    1;
 #endif
     if (CHECK_TRY_ERROR(g_all_sycl_device_count =
                             dpct::dev_mgr::instance().device_count()) != 0) {
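For context, the gate above only stays enabled when every enumerated device reports the aspect. A minimal standalone probe of the same capability check, assuming a DPC++ compiler whose headers define sycl::aspect::ext_oneapi_async_memory_alloc (illustrative sketch, not part of the patch):

#include <sycl/sycl.hpp>
#include <iostream>

// Print whether each SYCL device would pass the patch's per-device aspect check.
int main() {
    for (const auto & dev : sycl::device::get_devices()) {
        const bool ok = dev.has(sycl::aspect::ext_oneapi_async_memory_alloc);
        std::cout << dev.get_info<sycl::info::device::name>()
                  << ": async memory alloc " << (ok ? "supported" : "not supported") << '\n';
    }
}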
@@ -3050,22 +3049,19 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
 }
 
 // Helper functions to unify device memory allocation for both async and sync paths
-static inline void * sycl_malloc_opt_async(dpct::queue_ptr stream,
-                                           sycl::usm::alloc alloc_type,
-                                           size_t size,
-                                           bool use_async) {
+static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size, bool use_async) {
 #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
     if (use_async) {
-        return syclex::async_malloc(*stream, alloc_type, size);
+        return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
     }
 #else
     // If async allocation extension is not available, we should have never passed use_async=true
     GGML_ASSERT(!use_async);
 #endif
-    return sycl::malloc(size, *stream, alloc_type);
+    return sycl::malloc(size, *stream, sycl::usm::alloc::device);
 }
 
-static inline void sycl_free_opt_async(dpct::queue_ptr stream, void * ptr, bool use_async) {
+static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr, bool use_async) {
 #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
     if (use_async) {
         syclex::async_free(*stream, ptr);
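A hedged sketch of the call pattern these helpers unify; the helper names and the g_ggml_sycl_use_async_mem_op flag come from the patch, while the wrapper function itself is hypothetical:

// Hypothetical caller pairing the two helpers. On the async path the free is
// recorded onto the stream (graph-safe); on the sync fallback the host must
// wait before the scratch buffer can be released safely.
static void with_device_scratch(dpct::queue_ptr stream, size_t size) {
    void * scratch = sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op);
    // ... enqueue kernels that read/write scratch ...
    if (!g_ggml_sycl_use_async_mem_op) {
        stream->wait();  // host wait is only legal outside graph recording
    }
    sycl_ext_free(stream, scratch, g_ggml_sycl_use_async_mem_op);
}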
@@ -3080,13 +3076,11 @@ static inline void sycl_free_opt_async(dpct::queue_ptr stream, void * ptr, bool
 
 static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
                             dpct::queue_ptr stream) {
-    const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
-    uint8_t * tmp_buf =
-        static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
 
     sycl::event copy_event;
     SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
-    if (!use_async) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         copy_event.wait();
     }
 
@@ -3108,10 +3102,10 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
         }
         *(d_ptr + ib) = x[ib].d;
     });
-    if (!use_async) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         reorder_event.wait_and_throw();
     }
-    sycl_free_opt_async(stream, tmp_buf, use_async);
+    sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
 }
 
 static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
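The conditional waits above follow the same rule: host synchronization is not allowed while recording to a SYCL graph, so event waits run only on the blocking fallback. Reduced to its essentials (illustrative; the flag and stream type are from the patch, the memcpy is a stand-in):

// Wait-elision pattern shared by all three reorder functions.
static void copy_then_maybe_wait(dpct::queue_ptr stream, void * dst, const void * src, size_t n) {
    sycl::event e = stream->memcpy(dst, src, n);
    if (!g_ggml_sycl_use_async_mem_op) {
        e.wait();  // sync path: buffer must be ready before the next step
    }
    // Async/graph path: ordering is expressed through queue/graph dependencies,
    // so the host never blocks and the sequence stays capturable.
}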
@@ -3120,13 +3114,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q4_K);
 
-    const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
-    uint8_t * tmp_buf =
-        static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
 
     sycl::event copy_event;
     SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
-    if (!use_async) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         copy_event.wait();
     }
 
@@ -3148,10 +3140,10 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
 
         dm_ptr[ib] = x[ib].dm;
     });
-    if (!use_async) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         reorder_event.wait_and_throw();
     }
-    sycl_free_opt_async(stream, tmp_buf, use_async);
+    sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
 }
 
 static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3160,13 +3152,11 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q6_K);
 
-    const bool use_async = !g_ggml_sycl_disable_async_mem_alloc;
-    uint8_t * tmp_buf =
-        static_cast<uint8_t *>(sycl_malloc_opt_async(stream, sycl::usm::alloc::device, size, use_async));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size, g_ggml_sycl_use_async_mem_op));
 
     sycl::event copy_event;
     SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
-    if (!use_async) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         copy_event.wait();
     }
 
@@ -3175,34 +3165,33 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
     auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
     sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
 
-    auto reorder_event = stream->parallel_for(nblocks,
-        [=](auto i) {
-            const block_q6_K * x = (const block_q6_K *) tmp_buf;
-            const int ib = i;
-
-            const uint8_t * ql = x[ib].ql;
-            const uint8_t * qh = x[ib].qh;
-            uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
-            uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
-            uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
-
-            for (int j = 0; j < QK_K / 2; ++j) {
-                base_ql_ptr[j] = ql[j];
-            }
-            for (int j = 0; j < QK_K / 4; ++j) {
-                base_qh_ptr[j] = qh[j];
-            }
-
-            for (int j = 0; j < QK_K / 16; ++j) {
-                base_scales_ptr[j] = x[ib].scales[j];
-            }
-
-            dm_ptr[ib] = x[ib].d;
-        });
-    if (!use_async) {
+    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
+        const block_q6_K * x = (const block_q6_K *) tmp_buf;
+        const int ib = i;
+
+        const uint8_t * ql = x[ib].ql;
+        const uint8_t * qh = x[ib].qh;
+        uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
+        uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
+        uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            base_ql_ptr[j] = ql[j];
+        }
+        for (int j = 0; j < QK_K / 4; ++j) {
+            base_qh_ptr[j] = qh[j];
+        }
+
+        for (int j = 0; j < QK_K / 16; ++j) {
+            base_scales_ptr[j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].d;
+    });
+    if (!g_ggml_sycl_use_async_mem_op) {
         reorder_event.wait_and_throw();
     }
-    sycl_free_opt_async(stream, tmp_buf, use_async);
+    sycl_ext_free(stream, tmp_buf, g_ggml_sycl_use_async_mem_op);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
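As a cross-check on the q6_K kernel's pointer arithmetic: the reordered buffer packs all ql bytes, then all qh bytes, then all scales, then the per-block d values. A small bookkeeping sketch (assumes QK_K = 256 as in ggml; ql_ptr/qh_ptr are set up just above the hunk):

// Byte budget implied by the reordered q6_K layout built above.
static size_t reordered_q6_k_bytes(size_t nblocks) {
    const size_t ql     = (QK_K / 2)  * nblocks;         // low 4 bits of each quant
    const size_t qh     = (QK_K / 4)  * nblocks;         // high 2 bits of each quant
    const size_t scales = (QK_K / 16) * nblocks;         // one int8 scale per 16 quants
    const size_t dm     = sizeof(sycl::half) * nblocks;  // per-block super-scale d
    return ql + qh + scales + dm;  // 210 bytes per block, same total as the input
}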
@@ -4116,7 +4105,7 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
     // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
     // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
     // in reordering.
-    if (g_ggml_sycl_disable_async_mem_alloc) {
+    if (!g_ggml_sycl_use_async_mem_op) {
         GGML_LOG_INFO(
             "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
             "oneAPI async memory allocation extension "