 #include <regex>
 
 #include <sycl/sycl.hpp>
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+#    include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
+#endif
 #include <sycl/half_type.hpp>
 
 #include "ggml-sycl.h"
@@ -54,6 +57,7 @@ int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
 int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
+int g_ggml_sycl_use_async_mem_op = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -237,7 +241,20 @@ static void ggml_check_sycl() try {
         fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
 #endif
 */
-
+        // Currently, we only use async malloc / free when graphs are enabled, as it is required for the calls to be
+        // properly recorded. As this SYCL extension matures it may be beneficial to enable it as the default path and
+        // in other places.
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+        g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
+        if (g_ggml_sycl_use_async_mem_op) {
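+            // Only keep the async path if every device reports the ext_oneapi_async_memory_alloc aspect.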
+            for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
+                if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
+                    g_ggml_sycl_use_async_mem_op = 0;
+                    break;
+                }
+            }
+        }
+#endif
         if (CHECK_TRY_ERROR(g_all_sycl_device_count =
                             dpct::dev_mgr::instance().device_count()) != 0) {
             initialized = true;
@@ -3031,19 +3048,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
+// Helper functions to unify device memory allocation for both async and sync paths
+static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
+    bool use_async = g_ggml_sycl_use_async_mem_op;
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+    if (use_async) {
+        return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
+    }
+#else
+    // If the async allocation extension is not available, use_async should always be false.
+    GGML_ASSERT(!use_async);
+#endif
+    return sycl::malloc(size, *stream, sycl::usm::alloc::device);
+}
+
+static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
+    bool use_async = g_ggml_sycl_use_async_mem_op;
+#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
+    if (use_async) {
+        syclex::async_free(*stream, ptr);
+        return;
+    }
+#else
+    // If the async allocation extension is not available, use_async should always be false.
+    GGML_ASSERT(!use_async);
+#endif
+    sycl::free(ptr, *stream);
+}
+
 static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
                             dpct::queue_ptr stream) {
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
-            .wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
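+    // Host-side waits are not allowed while recording to a SYCL graph, so the copy and reorder
+    // events below are only waited on when the async allocation path is disabled.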
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
+
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
     auto qs_ptr      = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
-    stream->parallel_for(
+    auto reorder_event = stream->parallel_for(
         size / sizeof(block_q4_0),
             [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             const block_q4_0* x = (const block_q4_0*)tmp_buf;
@@ -3054,9 +3103,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        }).wait_and_throw();
-
-    sycl::free(tmp_buf, *stream);
+        });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3065,14 +3116,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q4_K);
 
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
 
     auto * qs_ptr     = data_device;
     auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
     auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
 
-    stream->parallel_for(nblocks, [=](auto i) {
+    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
         const block_q4_K * x  = (const block_q4_K *) tmp_buf;
         const int           ib = i;
 
@@ -3085,9 +3141,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
         }
 
         dm_ptr[ib] = x[ib].dm;
-    }).wait_and_throw();
-
-    sycl::free(tmp_buf, *stream);
+    });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3096,42 +3154,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
 
     const int nblocks = size / sizeof(block_q6_K);
 
-    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
+
+    sycl::event copy_event;
+    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+    if (!g_ggml_sycl_use_async_mem_op) {
+        copy_event.wait();
+    }
 
     auto *       ql_ptr     = data_device;
     auto *       qh_ptr     = ql_ptr + (QK_K / 2) * nblocks;
     auto *       scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
     sycl::half * dm_ptr     = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
 
-    stream
-        ->parallel_for(nblocks,
-                       [=](auto i) {
-                           const block_q6_K * x  = (const block_q6_K *) tmp_buf;
-                           const int           ib = i;
-
-                           const uint8_t * ql              = x[ib].ql;
-                           const uint8_t * qh              = x[ib].qh;
-                           uint8_t *       base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;
-                           uint8_t *       base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;
-                           uint8_t *       base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
+        const block_q6_K * x  = (const block_q6_K *) tmp_buf;
+        const int           ib = i;
 
-                           for (int j = 0; j < QK_K / 2; ++j) {
-                               base_ql_ptr[j] = ql[j];
-                           }
-                           for (int j = 0; j < QK_K / 4; ++j) {
-                               base_qh_ptr[j] = qh[j];
-                           }
+        const uint8_t * ql              = x[ib].ql;
+        const uint8_t * qh              = x[ib].qh;
+        uint8_t *       base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;
+        uint8_t *       base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;
+        uint8_t *       base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
 
-                           for (int j = 0; j < QK_K / 16; ++j) {
-                               base_scales_ptr[j] = x[ib].scales[j];
-                           }
+        for (int j = 0; j < QK_K / 2; ++j) {
+            base_ql_ptr[j] = ql[j];
+        }
+        for (int j = 0; j < QK_K / 4; ++j) {
+            base_qh_ptr[j] = qh[j];
+        }
 
-                           dm_ptr[ib] = x[ib].d;
-                       })
-        .wait_and_throw();
+        for (int j = 0; j < QK_K / 16; ++j) {
+            base_scales_ptr[j] = x[ib].scales[j];
+        }
 
-    sycl::free(tmp_buf, *stream);
+        dm_ptr[ib] = x[ib].d;
+    });
+    if (!g_ggml_sycl_use_async_mem_op) {
+        reorder_event.wait_and_throw();
+    }
+    sycl_ext_free(stream, tmp_buf);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -4056,6 +4118,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
                 GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
                               ggml_op_name(node_op));
                 return false;
+            case GGML_OP_MUL_MAT:
+                // We cannot use graphs with ggml_sycl_mul_mat() when the SYCL async memory allocation extension is not
+                // available, as the reordering path then relies on SYCL malloc / free and host wait calls, none of
+                // which are supported when recording to a graph.
+                if (!g_ggml_sycl_use_async_mem_op) {
+                    GGML_LOG_INFO(
+                        "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
+                        "oneAPI async memory allocation extension "
+                        "%s\n",
+                        __func__, ggml_op_name(node_op));
+                    return false;
+                }
         }
     }
     return true;