@@ -3809,11 +3809,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
38093809 }
38103810}
38113811
3812+ #ifdef GGML_SYCL_GRAPH
3813+ static bool check_graph_compatibility (ggml_cgraph * cgraph) {
3814+ if (ggml_sycl_info ().device_count > 1 ) {
3815+ // A sycl_ex::command_graph object can only be created for a single device
3816+ GGML_LOG_INFO (" %s: disabling SYCL graphs due to multiple devices\n " , __func__);
3817+ return false ;
3818+ }
3819+
3820+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3821+ const ggml_op node_op = cgraph->nodes [i]->op ;
3822+ switch (node_op) {
3823+ default :
3824+ break ;
3825+ case GGML_OP_CONCAT:
3826+ // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
3827+ // but wait() can't be called on the events returned by a queue recording
3828+ // to a graph.
3829+ [[fallthrough]];
3830+ case GGML_OP_MUL_MAT_ID:
3831+ // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
3832+ // submitting a memcpy operation, but wait() can't be called on a queue that
3833+ // is recording to a graph.
3834+ GGML_LOG_INFO (" %s: disabling SYCL graphs due to unsupported node type %s\n " , __func__,
3835+ ggml_op_name (node_op));
3836+ return false ;
3837+ }
3838+ }
3839+ return true ;
3840+ }
3841+ #endif
3842+
38123843static ggml_status ggml_backend_sycl_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
38133844 auto * sycl_ctx = static_cast <ggml_backend_sycl_context *>(backend->context );
38143845
38153846#ifdef GGML_SYCL_GRAPH
3816- if (!g_ggml_sycl_disable_graph) {
3847+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility (cgraph);
3848+ if (use_sycl_graph) {
38173849 const bool graph_support = dpct::get_device (sycl_ctx->device ).has (sycl::aspect::ext_oneapi_limited_graph);
38183850 if (!graph_support) {
38193851 GGML_SYCL_DEBUG (" [SYCL-GRAPH] can not use graphs on device:%d\n " , sycl_ctx->device );
0 commit comments