@@ -327,6 +327,18 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
327327 backend->iface .synchronize (backend);
328328}
329329
330+ bool ggml_backend_supports_graph_plan (ggml_backend_t backend) {
331+ GGML_ASSERT (backend);
332+
333+ return (bool ) backend->iface .graph_plan_create ;
334+ }
335+
336+ bool ggml_backend_supports_graph_plan_update (ggml_backend_t backend) {
337+ GGML_ASSERT (backend);
338+
339+ return (bool ) backend->iface .graph_plan_update ;
340+ }
341+
330342ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph) {
331343 GGML_ASSERT (backend);
332344 GGML_ASSERT (backend->iface .graph_plan_create != NULL );
@@ -341,6 +353,13 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
341353 backend->iface .graph_plan_free (backend, plan);
342354}
343355
356+ void ggml_backend_graph_plan_update (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph) {
357+ GGML_ASSERT (backend);
358+ GGML_ASSERT (backend->iface .graph_plan_update != NULL );
359+
360+ backend->iface .graph_plan_update (backend, plan, cgraph);
361+ }
362+
344363enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
345364 GGML_ASSERT (backend);
346365 GGML_ASSERT (backend->iface .graph_plan_compute != NULL );
@@ -675,6 +694,11 @@ struct ggml_backend_sched_split {
675694 struct ggml_cgraph graph;
676695};
677696
697+ struct ggml_backend_sched_plan {
698+ int backend_id;
699+ ggml_backend_graph_plan_t plan;
700+ };
701+
678702struct ggml_backend_sched {
679703 bool is_reset; // true if the scheduler has been reset since the last graph split
680704 bool is_alloc;
@@ -704,6 +728,12 @@ struct ggml_backend_sched {
704728 int n_splits;
705729 int splits_capacity;
706730
731+ // graph plans
732+ struct ggml_backend_sched_plan * plans;
733+ int n_plans;
734+ int plans_capacity;
735+ bool plan_dirty;
736+
707737 // pipeline parallelism support
708738 int n_copies;
709739 int cur_copy;
@@ -908,6 +938,16 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
908938 }
909939}
910940
941+ static void ggml_backend_sched_free_plans (ggml_backend_sched_t sched) {
942+ for (int i = 0 ; i < sched->n_plans ; i++) {
943+ ggml_backend_t backend = sched->backends [sched->plans [i].backend_id ];
944+ if (ggml_backend_supports_graph_plan (backend)) {
945+ ggml_backend_graph_plan_free (backend, sched->plans [i].plan );
946+ }
947+ }
948+ sched->n_plans = 0 ;
949+ }
950+
911951// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
912952void ggml_backend_sched_split_graph (ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
913953 // reset splits
@@ -1372,6 +1412,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
13721412 assert (graph_copy->size > graph_copy->n_leafs );
13731413 graph_copy->leafs [graph_copy->n_leafs ++] = leaf;
13741414 }
1415+ sched->plan_dirty = true ;
13751416}
13761417
13771418static bool ggml_backend_sched_alloc_splits (ggml_backend_sched_t sched) {
@@ -1413,6 +1454,62 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
14131454 return true ;
14141455}
14151456
1457+ static void ggml_backend_sched_update_plans (ggml_backend_sched_t sched) {
1458+ // create graph plans
1459+ if (sched->plan_dirty ) {
1460+ bool create_new_plans;
1461+ if (sched->n_plans == sched->n_splits ) {
1462+ create_new_plans = false ;
1463+ for (int i = 0 ; i < sched->n_splits ; i++) {
1464+ if (sched->splits [i].backend_id != sched->plans [i].backend_id ) {
1465+ create_new_plans = true ;
1466+ break ;
1467+ }
1468+ }
1469+ } else {
1470+ create_new_plans = true ;
1471+ }
1472+ if (create_new_plans) {
1473+ // free previous and recreate new plans
1474+ ggml_backend_sched_free_plans (sched);
1475+ if (sched->plans_capacity < sched->n_splits ) {
1476+ while (sched->plans_capacity < sched->n_splits ) {
1477+ sched->plans_capacity *= 2 ;
1478+ }
1479+ sched->plans = (ggml_backend_sched_plan *) realloc (
1480+ sched->plans , sched->plans_capacity * sizeof (struct ggml_backend_sched_plan ));
1481+ GGML_ASSERT (sched->plans );
1482+ }
1483+ sched->n_plans = sched->n_splits ;
1484+ for (int i = 0 ; i < sched->n_splits ; i++) {
1485+ ggml_backend_t backend = sched->backends [sched->splits [i].backend_id ];
1486+ sched->plans [i].backend_id = sched->splits [i].backend_id ;
1487+ if (ggml_backend_supports_graph_plan (backend)) {
1488+ sched->plans [i].plan = ggml_backend_graph_plan_create (backend, &sched->splits [i].graph );
1489+ } else {
1490+ sched->plans [i].plan = nullptr ;
1491+ }
1492+ }
1493+ } else {
1494+ // update existing plans
1495+ for (int i = 0 ; i < sched->n_splits ; i++) {
1496+ ggml_backend_t backend = sched->backends [sched->splits [i].backend_id ];
1497+ if (ggml_backend_supports_graph_plan (backend)) {
1498+ if (ggml_backend_supports_graph_plan_update (backend)) {
1499+ ggml_backend_graph_plan_update (backend, sched->plans [i].plan , &sched->splits [i].graph );
1500+ } else {
1501+ ggml_backend_graph_plan_free (backend, sched->plans [i].plan );
1502+ sched->plans [i].plan = ggml_backend_graph_plan_create (backend, &sched->splits [i].graph );
1503+ }
1504+ } else {
1505+ sched->plans [i].plan = nullptr ;
1506+ }
1507+ }
1508+ }
1509+ sched->plan_dirty = false ;
1510+ }
1511+ }
1512+
14161513static enum ggml_status ggml_backend_sched_compute_splits (ggml_backend_sched_t sched) {
14171514 GGML_ASSERT (sched);
14181515 struct ggml_backend_sched_split * splits = sched->splits ;
@@ -1421,6 +1518,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
14211518 std::vector<int32_t > ids;
14221519 std::vector<ggml_bitset_t > used_ids;
14231520
1521+ ggml_backend_sched_update_plans (sched);
1522+
14241523 for (int split_id = 0 ; split_id < sched->n_splits ; split_id++) {
14251524 struct ggml_backend_sched_split * split = &splits[split_id];
14261525 int split_backend_id = split->backend_id ;
@@ -1550,7 +1649,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
15501649 }
15511650
15521651 if (!sched->callback_eval ) {
1553- enum ggml_status ec = ggml_backend_graph_compute_async (split_backend, &split->graph );
1652+ enum ggml_status ec;
1653+ if (ggml_backend_supports_graph_plan (split_backend) && sched->plans [split_id].plan ) {
1654+ ec = ggml_backend_graph_plan_compute (split_backend, sched->plans [split_id].plan );
1655+ } else {
1656+ ec = ggml_backend_graph_compute_async (split_backend, &split->graph );
1657+ }
15541658 if (ec != GGML_STATUS_SUCCESS) {
15551659 return ec;
15561660 }
@@ -1637,6 +1741,10 @@ ggml_backend_sched_t ggml_backend_sched_new(
16371741 sched->splits = (ggml_backend_sched_split *) calloc (initial_splits_capacity, sizeof (sched->splits [0 ]));
16381742 sched->splits_capacity = initial_splits_capacity;
16391743
1744+ const int initial_plans_capacity = 16 ;
1745+ sched->plans = (ggml_backend_sched_plan *) calloc (initial_plans_capacity, sizeof (sched->plans [0 ]));
1746+ sched->plans_capacity = initial_plans_capacity;
1747+
16401748 for (int b = 0 ; b < n_backends; b++) {
16411749 sched->backends [b] = backends[b];
16421750 sched->bufts [b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type (backends[b]);
@@ -1670,6 +1778,8 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
16701778 ggml_free (sched->ctx );
16711779 ggml_hash_set_free (&sched->hash_set );
16721780 free (sched->splits );
1781+ ggml_backend_sched_free_plans (sched);
1782+ free (sched->plans );
16731783 free (sched->hv_tensor_backend_ids );
16741784 free (sched->hv_tensor_copies );
16751785 free (sched->node_backend_ids );
0 commit comments