 #include <regex>
 
 #include <sycl/sycl.hpp>
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+#    include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
+#endif
 #include <sycl/half_type.hpp>
 
 #include "ggml-sycl.h"
@@ -54,6 +57,7 @@ int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
 int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
+int g_ggml_sycl_use_async_mem_op = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -237,7 +241,20 @@ static void ggml_check_sycl() try {
     fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
 #endif
 */
-
+    // Currently, we only use async malloc / free when graphs are enabled, as this is required for the calls to be
+    // properly recorded. As this SYCL extension matures, it may be beneficial to enable it as the default path and
+    // in other places.
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+    g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
+    if (g_ggml_sycl_use_async_mem_op) {
+        for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
+            if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
+                g_ggml_sycl_use_async_mem_op = 0;
+                break;
+            }
+        }
+    }
+#endif
     if (CHECK_TRY_ERROR(g_all_sycl_device_count =
                             dpct::dev_mgr::instance().device_count()) != 0) {
         initialized = true;
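Aside (not part of the patch): the hunk above enables the async path only when every device reports the `ext_oneapi_async_memory_alloc` aspect. A minimal standalone sketch of that same check in plain SYCL, without the dpct device manager; the GPU-only device filter and the whole program are illustrative assumptions, not ggml-sycl code:

```cpp
#include <sycl/sycl.hpp>
#include <iostream>

int main() {
    // Usable only if every GPU device exposes the async memory allocation aspect,
    // i.e. the same condition that decides g_ggml_sycl_use_async_mem_op above.
    bool all_support_async_alloc = true;
#if SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
    for (const auto & dev : sycl::device::get_devices(sycl::info::device_type::gpu)) {
        if (!dev.has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
            all_support_async_alloc = false;
            break;
        }
    }
#else
    // The compiler/runtime does not ship the extension at all.
    all_support_async_alloc = false;
#endif
    std::cout << "async memory allocation usable on all GPUs: "
              << (all_support_async_alloc ? "yes" : "no") << std::endl;
    return 0;
}
```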
@@ -3031,19 +3048,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
+// Helper functions to unify device memory allocation for both async and sync paths
+static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
+    bool use_async = g_ggml_sycl_use_async_mem_op;
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+    if (use_async) {
+        return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
+    }
+#else
+    // If async allocation extension is not available, use_async should always be false.
+    GGML_ASSERT(!use_async);
+#endif
+    return sycl::malloc(size, *stream, sycl::usm::alloc::device);
+}
+
+static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
+    bool use_async = g_ggml_sycl_use_async_mem_op;
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+    if (use_async) {
+        syclex::async_free(*stream, ptr);
+        return;
+    }
+#else
+    // If async allocation extension is not available, use_async should always be false.
+    GGML_ASSERT(!use_async);
+#endif
+    sycl::free(ptr, *stream);
+}
+
 static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
                             dpct::queue_ptr stream) {
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
-                            .wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
+
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
     auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
-    stream->parallel_for(
+    auto reorder_event = stream->parallel_for(
         size / sizeof(block_q4_0),
         [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             const block_q4_0* x = (const block_q4_0*)tmp_buf;
@@ -3054,9 +3103,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        }).wait_and_throw();
-
-    sycl::free(tmp_buf, *stream);
+        });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3065,14 +3116,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q4_K);
 
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
 
     auto * qs_ptr = data_device;
     auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
     auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
 
-    stream->parallel_for(nblocks, [=](auto i) {
+    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
         const block_q4_K * x = (const block_q4_K *) tmp_buf;
         const int ib = i;
 
@@ -3085,9 +3141,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
         }
 
         dm_ptr[ib] = x[ib].dm;
-    }).wait_and_throw();
-
-    sycl::free(tmp_buf, *stream);
+    });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3096,42 +3154,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q6_K);
 
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
 
     auto * ql_ptr = data_device;
     auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
     auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
     sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
 
-    stream
-        ->parallel_for(nblocks,
-                       [=](auto i) {
-                           const block_q6_K * x = (const block_q6_K *) tmp_buf;
-                           const int ib = i;
-
-                           const uint8_t * ql = x[ib].ql;
-                           const uint8_t * qh = x[ib].qh;
-                           uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
-                           uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
-                           uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
+        const block_q6_K * x = (const block_q6_K *) tmp_buf;
+        const int ib = i;
 
-                           for (int j = 0; j < QK_K / 2; ++j) {
-                               base_ql_ptr[j] = ql[j];
-                           }
-                           for (int j = 0; j < QK_K / 4; ++j) {
-                               base_qh_ptr[j] = qh[j];
-                           }
+        const uint8_t * ql = x[ib].ql;
+        const uint8_t * qh = x[ib].qh;
+        uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
+        uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
+        uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
 
-                           for (int j = 0; j < QK_K / 16; ++j) {
-                               base_scales_ptr[j] = x[ib].scales[j];
-                           }
+        for (int j = 0; j < QK_K / 2; ++j) {
+            base_ql_ptr[j] = ql[j];
+        }
+        for (int j = 0; j < QK_K / 4; ++j) {
+            base_qh_ptr[j] = qh[j];
+        }
 
-                           dm_ptr[ib] = x[ib].d;
-                       })
-        .wait_and_throw();
+        for (int j = 0; j < QK_K / 16; ++j) {
+            base_scales_ptr[j] = x[ib].scales[j];
+        }
 
-    sycl::free(tmp_buf, *stream);
+        dm_ptr[ib] = x[ib].d;
+    });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -4056,6 +4118,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
                 GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
                               ggml_op_name(node_op));
                 return false;
+            case GGML_OP_MUL_MAT:
+                // We cannot use graphs with ggml_sycl_mul_mat() when the SYCL async memory allocation extension is
+                // not available, as the reordering path relies on SYCL malloc / free and host wait calls, none of
+                // which are supported while recording to a graph.
+                if (!g_ggml_sycl_use_async_mem_op) {
+                    GGML_LOG_INFO(
+                        "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
+                        "oneAPI async memory allocation extension "
+                        "%s\n",
+                        __func__, ggml_op_name(node_op));
+                    return false;
+                }
         }
     }
     return true;
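For reference (not part of the patch), here is a minimal sketch of the submission pattern the reorder functions follow after this change: allocate the temporary through the async extension and skip host-side waits so the sequence can be recorded into a SYCL graph, otherwise fall back to blocking USM calls. The function name `scale_via_tmp`, its doubling kernel, and the in-order-queue requirement are illustrative assumptions standing in for the reorder work on ggml-sycl's streams:

```cpp
#include <cassert>
#include <sycl/sycl.hpp>
#if SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
#    include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
namespace syclex = sycl::ext::oneapi::experimental;
#endif

// Doubles n floats in place through a temporary device buffer, mirroring the
// allocate -> copy -> kernel -> free sequence of the reorder functions above.
// Assumes q is an in-order queue (as ggml-sycl's streams are), so submissions
// stay ordered even when the host-side waits are skipped.
void scale_via_tmp(sycl::queue & q, float * data, size_t n, bool use_async) {
    float * tmp = nullptr;
#if SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
    if (use_async) {
        // Graph-recordable path: allocation is asynchronous and no host waits are issued.
        tmp = static_cast<float *>(syclex::async_malloc(q, sycl::usm::alloc::device, n * sizeof(float)));
    }
#else
    assert(!use_async);  // without the extension only the blocking path is valid
#endif
    if (tmp == nullptr) {
        tmp = sycl::malloc_device<float>(n, q);
    }

    auto copy_ev = q.memcpy(tmp, data, n * sizeof(float));
    if (!use_async) {
        copy_ev.wait();
    }

    auto kernel_ev = q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) { data[i] = 2.0f * tmp[i]; });
    if (!use_async) {
        kernel_ev.wait_and_throw();
    }

#if SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
    if (use_async) {
        syclex::async_free(q, tmp);
        return;
    }
#endif
    sycl::free(tmp, q);
}
```

Keeping the whole sequence free of host waits and host-side malloc/free is what allows GGML_OP_MUL_MAT nodes with reordering to remain graph-compatible, which is the condition check_graph_compatibility() enforces above.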