Skip to content

Commit db79ebf

Browse files
committed
Make plugin execution sequential in topologically sorted order
1 parent d225c3c commit db79ebf

File tree

5 files changed

+96
-112
lines changed

5 files changed

+96
-112
lines changed

cmd/epp/runner/runner.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf
459459

460460
// Add requestControl plugins
461461
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
462-
// Check prepare data plugins for cycles.
463-
if _, err := r.requestControlConfig.PrepareDataPluginGraph(); err != nil {
462+
// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
463+
if r.requestControlConfig.PrepareDataPluginGraph() != nil {
464464
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
465465
}
466466

pkg/epp/requestcontrol/dag.go

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License.
1616

1717
package requestcontrol
1818

19-
import "errors"
19+
import (
20+
"errors"
21+
"slices"
22+
)
2023

2124
// buildDAG builds a dependency graph among data preparation plugins based on their
2225
// produced and consumed data keys.
@@ -50,50 +53,76 @@ func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
5053
return dag
5154
}
5255

53-
// prepareDataGraph builds a DAG of data preparation plugins and checks for cycles.
56+
// prepareDataGraph builds the dependency graph and returns the plugins ordered in topological order.
5457
// If there is a cycle, it returns an error.
55-
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) {
58+
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, []PrepareDataPlugin, error) {
5659
dag := buildDAG(plugins)
57-
58-
// Check for cycles in the DAG.
59-
if cycleExistsInDAG(dag) {
60-
return nil, errors.New("cycle detected in data preparation plugin dependencies")
60+
nameToPlugin := map[string]PrepareDataPlugin{}
61+
for _, plugin := range plugins {
62+
nameToPlugin[plugin.TypedName().String()] = plugin
63+
}
64+
sortedPlugins, err := topologicalSort(dag)
65+
if err != nil {
66+
return nil, nil, err
67+
}
68+
orderedPlugins := []PrepareDataPlugin{}
69+
for _, pluginName := range sortedPlugins {
70+
orderedPlugins = append(orderedPlugins, nameToPlugin[pluginName])
6171
}
6272

63-
return dag, nil
73+
return dag, orderedPlugins, err
6474
}
6575

66-
// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list.
67-
func cycleExistsInDAG(dag map[string][]string) bool {
68-
visited := make(map[string]bool)
69-
recStack := make(map[string]bool)
76+
// TopologicalSort performs Kahn's Algorithm on a DAG.
77+
// It returns the sorted order or an error if a cycle is detected.
78+
func topologicalSort(graph map[string][]string) ([]string, error) {
79+
// 1. Initialize in-degree map
80+
inDegree := make(map[string]int)
7081

71-
var dfs func(string) bool
72-
dfs = func(node string) bool {
73-
if recStack[node] {
74-
return true // Cycle detected
82+
// Ensure all nodes are present in the inDegree map, even those with no dependencies
83+
for u, neighbors := range graph {
84+
if _, ok := inDegree[u]; !ok {
85+
inDegree[u] = 0
7586
}
76-
if visited[node] {
77-
return false
87+
for _, v := range neighbors {
88+
inDegree[v]++ // Increment in-degree for the destination node
7889
}
79-
visited[node] = true
80-
recStack[node] = true
90+
}
8191

82-
for _, neighbor := range dag[node] {
83-
if dfs(neighbor) {
84-
return true
85-
}
92+
// 2. Initialize the queue with nodes having 0 in-degree
93+
var queue []string
94+
for node, degree := range inDegree {
95+
if degree == 0 {
96+
queue = append(queue, node)
8697
}
87-
recStack[node] = false
88-
return false
8998
}
9099

91-
for pluginName := range dag {
92-
if !visited[pluginName] {
93-
if dfs(pluginName) {
94-
return true
100+
var result []string
101+
102+
// 3. Process the queue
103+
for len(queue) > 0 {
104+
// Dequeue
105+
u := queue[0]
106+
queue = queue[1:]
107+
108+
result = append(result, u)
109+
110+
// Decrease in-degree of neighbors
111+
if neighbors, ok := graph[u]; ok {
112+
for _, v := range neighbors {
113+
inDegree[v]--
114+
if inDegree[v] == 0 {
115+
queue = append(queue, v)
116+
}
95117
}
96118
}
97119
}
98-
return false
120+
121+
// 4. Check for cycles
122+
// If the result size != total nodes, there is a cycle
123+
if len(result) != len(inDegree) {
124+
return nil, errors.New("cycle detected: graph is not a DAG")
125+
}
126+
slices.Reverse(result)
127+
return result, nil
99128
}

pkg/epp/requestcontrol/dag_test.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package requestcontrol
1818

1919
import (
2020
"context"
21+
"maps"
2122
"testing"
2223

2324
"github.com/google/go-cmp/cmp"
@@ -82,7 +83,7 @@ func TestPrepareDataGraph(t *testing.T) {
8283
expectError: false,
8384
},
8485
{
85-
name: "Simple linear dependency (A -> B -> C)",
86+
name: "Simple linear dependency (C -> B -> A)",
8687
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
8788
expectedDAG: map[string][]string{
8889
"A/mock": {},
@@ -92,7 +93,7 @@ func TestPrepareDataGraph(t *testing.T) {
9293
expectError: false,
9394
},
9495
{
95-
name: "DAG with multiple dependencies (A -> B, A -> D)",
96+
name: "DAG with multiple dependencies (B -> A, D -> A, E independent)",
9697
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
9798
expectedDAG: map[string][]string{
9899
"A/mock": {},
@@ -108,17 +109,11 @@ func TestPrepareDataGraph(t *testing.T) {
108109
expectedDAG: nil,
109110
expectError: true,
110111
},
111-
{
112-
name: "Complex graph with a cycle",
113-
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY},
114-
expectedDAG: nil,
115-
expectError: true,
116-
},
117112
}
118113

119114
for _, tc := range testCases {
120115
t.Run(tc.name, func(t *testing.T) {
121-
dag, err := prepareDataGraph(tc.plugins)
116+
dag, orderedPlugins, err := prepareDataGraph(tc.plugins)
122117

123118
if tc.expectError {
124119
assert.Error(t, err)
@@ -129,9 +124,7 @@ func TestPrepareDataGraph(t *testing.T) {
129124

130125
// Normalize the slices in the maps for consistent comparison
131126
normalizedDAG := make(map[string][]string)
132-
for k, v := range dag {
133-
normalizedDAG[k] = v
134-
}
127+
maps.Copy(normalizedDAG, dag)
135128
normalizedExpectedDAG := make(map[string][]string)
136129
for k, v := range tc.expectedDAG {
137130
normalizedExpectedDAG[k] = v
@@ -140,7 +133,27 @@ func TestPrepareDataGraph(t *testing.T) {
140133
if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
141134
t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
142135
}
136+
137+
orderedPluginNames := make([]string, len(orderedPlugins))
138+
for i, p := range orderedPlugins {
139+
orderedPluginNames[i] = p.TypedName().String()
140+
}
141+
assertTopologicalOrder(t, dag, orderedPlugins)
143142
}
144143
})
145144
}
146145
}
146+
147+
func assertTopologicalOrder(t *testing.T, dag map[string][]string, ordered []PrepareDataPlugin) {
148+
t.Helper()
149+
positions := make(map[string]int)
150+
for i, p := range ordered {
151+
positions[p.TypedName().String()] = i
152+
}
153+
154+
for node, dependencies := range dag {
155+
for _, dep := range dependencies {
156+
assert.Less(t, positions[dep], positions[node], "Dependency %s should come before %s", dep, node)
157+
}
158+
}
159+
}

pkg/epp/requestcontrol/plugin_executor.go

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package requestcontrol
1919
import (
2020
"context"
2121
"errors"
22-
"fmt"
2322
"time"
2423

2524
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -29,63 +28,9 @@ import (
2928
// So, a plugin is executed only after all its dependencies have been executed.
3029
// If there is a cycle or any plugin fails with error, it returns an error.
3130
func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
32-
// Build the DAG
33-
// The error validation happens on startup when loading the config. So, here there should not be any error.
34-
dag, err := prepareDataGraph(plugins)
35-
if err != nil {
36-
return err
37-
}
38-
// Create a readonly map of plugin name to plugin instance.
39-
nameToNode := map[string]PrepareDataPlugin{}
4031
for _, plugin := range plugins {
41-
nameToNode[plugin.TypedName().String()] = plugin
42-
}
43-
// Execute the DAG
44-
45-
// Channels to signal plugin execution completion.
46-
pluginExecuted := make(map[string]chan error)
47-
// The capacity of the channel is equal to the number of dependents + 1 (for itself).
48-
capacityMap := make(map[string]int)
49-
for pluginName := range dag {
50-
capacityMap[pluginName]++
51-
for _, dep := range dag[pluginName] {
52-
capacityMap[dep]++
53-
}
54-
}
55-
for pluginName, capacity := range capacityMap {
56-
pluginExecuted[pluginName] = make(chan error, capacity)
57-
}
58-
for pluginName, dependents := range dag {
59-
// Execute plugins based on dependencies.
60-
// Wait for the dependencies to complete before executing a plugin.
61-
go func(pN string, depds []string) {
62-
for _, dep := range depds {
63-
// Wait for the dependency plugin to signal completion.
64-
if err, closed := <-pluginExecuted[dep]; closed {
65-
if err != nil {
66-
for range cap(pluginExecuted[pN]) {
67-
// Notify all dependents about the failure.
68-
pluginExecuted[pN] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err)
69-
}
70-
// Do not execute this plugin as one of its dependencies failed.
71-
return
72-
}
73-
}
74-
}
75-
res := nameToNode[pN].PrepareRequestData(ctx, request, pods)
76-
for range cap(pluginExecuted[pN]) {
77-
// Notify all dependents about the completion.
78-
pluginExecuted[pN] <- res
79-
}
80-
}(pluginName, dependents)
81-
}
82-
83-
// Check for errors in plugin execution.
84-
// This will also ensure that all plugins have completed execution before returning.
85-
for pluginName := range dag {
86-
err := <-pluginExecuted[pluginName]
87-
if err != nil {
88-
return fmt.Errorf("prepare data plugin %s failed: %v", pluginName, err)
32+
if err := plugin.PrepareRequestData(ctx, request, pods); err != nil {
33+
return errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error())
8934
}
9035
}
9136
return nil

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ type Config struct {
4040
responseReceivedPlugins []ResponseReceived
4141
responseStreamingPlugins []ResponseStreaming
4242
responseCompletePlugins []ResponseComplete
43-
prepareDataPluginGraph map[string][]string
4443
}
4544

4645
// WithPreRequestPlugins sets the given plugins as the PreRequest plugins.
@@ -106,16 +105,14 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
106105
}
107106
}
108107

109-
// PrepareDataPluginGraph creates and returns the data dependency graph of PrepareData plugins.
108+
// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
110109
// If a cycle is detected, it returns an error.
111-
func (c *Config) PrepareDataPluginGraph() (map[string][]string, error) {
112-
if c.prepareDataPluginGraph != nil {
113-
return c.prepareDataPluginGraph, nil
114-
}
115-
graph, err := prepareDataGraph(c.prepareDataPlugins)
110+
func (c *Config) PrepareDataPluginGraph() error {
111+
_, plugins, err := prepareDataGraph(c.prepareDataPlugins)
116112
if err != nil {
117-
return nil, err
113+
return err
118114
}
119-
c.prepareDataPluginGraph = graph
120-
return graph, nil
115+
c.prepareDataPlugins = plugins
116+
117+
return nil
121118
}

0 commit comments

Comments
 (0)