@@ -327,6 +327,18 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
     backend->iface.synchronize(backend);
 }
 
+bool ggml_backend_supports_graph_plan(ggml_backend_t backend) {
+    GGML_ASSERT(backend);
+
+    return (bool) backend->iface.graph_plan_create;
+}
+
+bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend) {
+    GGML_ASSERT(backend);
+
+    return (bool) backend->iface.graph_plan_update;
+}
+
 ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.graph_plan_create != NULL);
@@ -341,6 +353,13 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
     backend->iface.graph_plan_free(backend, plan);
 }
 
+void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend);
+    GGML_ASSERT(backend->iface.graph_plan_update != NULL);
+
+    backend->iface.graph_plan_update(backend, plan, cgraph);
+}
+
 enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
@@ -675,6 +694,11 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
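+// a graph plan cached by the scheduler for one split, together with the backend it was created on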
+struct ggml_backend_sched_plan {
+    int backend_id;
+    ggml_backend_graph_plan_t plan;
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -704,6 +728,12 @@ struct ggml_backend_sched {
     int n_splits;
     int splits_capacity;
 
+    // graph plans
+    struct ggml_backend_sched_plan * plans;
+    int n_plans;
+    int plans_capacity;
+    bool plan_dirty; // set when the graph has been re-split and the cached plans need to be refreshed
+
     // pipeline parallelism support
     int n_copies;
     int cur_copy;
@@ -914,6 +944,16 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
     }
 }
 
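+// free all graph plans cached by the scheduler and reset the plan count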
+static void ggml_backend_sched_free_plans(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_plans; i++) {
+        ggml_backend_t backend = sched->backends[sched->plans[i].backend_id];
+        if (ggml_backend_supports_graph_plan(backend)) {
+            ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
+        }
+    }
+    sched->n_plans = 0;
+}
+
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
 void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
@@ -1378,6 +1418,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
         assert(graph_copy->size > graph_copy->n_leafs);
         graph_copy->leafs[graph_copy->n_leafs++] = leaf;
     }
+    sched->plan_dirty = true;
 }
 
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
@@ -1419,6 +1460,62 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     return true;
 }
 
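+// create or refresh the cached graph plans so that they match the current splits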
+static void ggml_backend_sched_update_plans(ggml_backend_sched_t sched) {
+    // create graph plans
+    if (sched->plan_dirty) {
+        bool create_new_plans;
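+        // existing plans can only be updated in place if the number of splits and their backend assignment are unchanged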
+        if (sched->n_plans == sched->n_splits) {
+            create_new_plans = false;
+            for (int i = 0; i < sched->n_splits; i++) {
+                if (sched->splits[i].backend_id != sched->plans[i].backend_id) {
+                    create_new_plans = true;
+                    break;
+                }
+            }
+        } else {
+            create_new_plans = true;
+        }
+        if (create_new_plans) {
+            // free previous and recreate new plans
+            ggml_backend_sched_free_plans(sched);
+            if (sched->plans_capacity < sched->n_splits) {
+                while (sched->plans_capacity < sched->n_splits) {
+                    sched->plans_capacity *= 2;
+                }
+                sched->plans = (ggml_backend_sched_plan *) realloc(
+                    sched->plans, sched->plans_capacity * sizeof(struct ggml_backend_sched_plan));
+                GGML_ASSERT(sched->plans);
+            }
+            sched->n_plans = sched->n_splits;
+            for (int i = 0; i < sched->n_splits; i++) {
+                ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
+                sched->plans[i].backend_id = sched->splits[i].backend_id;
+                if (ggml_backend_supports_graph_plan(backend)) {
+                    sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
+                } else {
+                    sched->plans[i].plan = nullptr;
+                }
+            }
+        } else {
+            // update existing plans
+            for (int i = 0; i < sched->n_splits; i++) {
+                ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
+                if (ggml_backend_supports_graph_plan(backend)) {
+                    if (ggml_backend_supports_graph_plan_update(backend)) {
+                        ggml_backend_graph_plan_update(backend, sched->plans[i].plan, &sched->splits[i].graph);
+                    } else {
+                        ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
+                        sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
+                    }
+                } else {
+                    sched->plans[i].plan = nullptr;
+                }
+            }
+        }
+        sched->plan_dirty = false;
+    }
+}
+
 static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
     GGML_ASSERT(sched);
     struct ggml_backend_sched_split * splits = sched->splits;
@@ -1427,6 +1524,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
     std::vector<int32_t> ids;
     std::vector<ggml_bitset_t> used_ids;
 
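+    // make sure the cached graph plans match the current splits before computing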
+    ggml_backend_sched_update_plans(sched);
+
     for (int split_id = 0; split_id < sched->n_splits; split_id++) {
         struct ggml_backend_sched_split * split = &splits[split_id];
         int split_backend_id = split->backend_id;
@@ -1556,7 +1655,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         }
 
         if (!sched->callback_eval) {
-            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
+            enum ggml_status ec;
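+            // use the cached graph plan when the backend provides one, otherwise fall back to async graph compute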
+            if (ggml_backend_supports_graph_plan(split_backend) && sched->plans[split_id].plan) {
+                ec = ggml_backend_graph_plan_compute(split_backend, sched->plans[split_id].plan);
+            } else {
+                ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
+            }
             if (ec != GGML_STATUS_SUCCESS) {
                 return ec;
             }
@@ -1643,6 +1747,10 @@ ggml_backend_sched_t ggml_backend_sched_new(
     sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));
     sched->splits_capacity = initial_splits_capacity;
 
+    const int initial_plans_capacity = 16;
+    sched->plans = (ggml_backend_sched_plan *) calloc(initial_plans_capacity, sizeof(sched->plans[0]));
+    sched->plans_capacity = initial_plans_capacity;
+
     for (int b = 0; b < n_backends; b++) {
         sched->backends[b] = backends[b];
         sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
@@ -1676,6 +1784,8 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     ggml_free(sched->ctx);
     ggml_hash_set_free(&sched->hash_set);
     free(sched->splits);
+    ggml_backend_sched_free_plans(sched);
+    free(sched->plans);
     free(sched->hv_tensor_backend_ids);
     free(sched->hv_tensor_copies);
     free(sched->node_backend_ids);