Skip to content

Commit 66bf69f

Browse files
committed
Parallelize execution of prepare data plugins as a DAG. Also detect data dependency cycles on startup.
1 parent 943e676 commit 66bf69f

File tree

6 files changed

+310
-13
lines changed

6 files changed

+310
-13
lines changed

cmd/epp/runner/runner.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf
459459

460460
// Add requestControl plugins
461461
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
462+
// Check prepare data plugins for cycles.
463+
if r.requestControlConfig.ValidatePrepareDataPlugins() != nil {
464+
return errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
465+
}
462466

463467
// Handler deprecated configuration options
464468
r.deprecatedConfigurationHelper(cfg, logger)

pkg/epp/requestcontrol/dag.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package requestcontrol
18+
19+
import "errors"
20+
21+
// buildDAG builds a dependency graph among data preparation plugins based on their
22+
// produced and consumed data keys.
23+
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
24+
dag := make(map[string][]string)
25+
for _, plugin := range plugins {
26+
dag[plugin.TypedName().String()] = []string{}
27+
}
28+
// Create dependency graph as a DAG.
29+
for i := range plugins {
30+
for j := range plugins {
31+
if i == j {
32+
continue
33+
}
34+
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
35+
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
36+
// For all the keys produced by plugin i, check if plugin j consumes any of them.
37+
// If yes, then j depends on i.
38+
for producedKey := range plugins[i].Produces() {
39+
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
40+
if _, ok := plugins[j].Consumes()[producedKey]; ok {
41+
iPluginName := plugins[i].TypedName().String()
42+
jPluginName := plugins[j].TypedName().String()
43+
dag[jPluginName] = append(dag[jPluginName], iPluginName)
44+
break
45+
}
46+
}
47+
}
48+
}
49+
}
50+
return dag
51+
}
52+
53+
// prepareDataGraph builds a DAG of data preparation plugins and checks for cycles.
54+
// If there is a cycle, it returns an error.
55+
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) {
56+
dag := buildDAG(plugins)
57+
58+
// Check for cycles in the DAG.
59+
if cycleExistsInDAG(dag) {
60+
return nil, errors.New("cycle detected in data preparation plugin dependencies")
61+
}
62+
63+
return dag, nil
64+
}
65+
66+
// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list.
67+
func cycleExistsInDAG(dag map[string][]string) bool {
68+
visited := make(map[string]bool)
69+
recStack := make(map[string]bool)
70+
71+
var dfs func(string) bool
72+
dfs = func(node string) bool {
73+
if recStack[node] {
74+
return true // Cycle detected
75+
}
76+
if visited[node] {
77+
return false
78+
}
79+
visited[node] = true
80+
recStack[node] = true
81+
82+
for _, neighbor := range dag[node] {
83+
if dfs(neighbor) {
84+
return true
85+
}
86+
}
87+
recStack[node] = false
88+
return false
89+
}
90+
91+
for pluginName := range dag {
92+
if !visited[pluginName] {
93+
if dfs(pluginName) {
94+
return true
95+
}
96+
}
97+
}
98+
return false
99+
}

pkg/epp/requestcontrol/dag_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package requestcontrol
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/google/go-cmp/cmp"
24+
"github.com/stretchr/testify/assert"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
27+
)
28+
29+
type mockPrepareRequestDataP struct {
30+
name string
31+
produces map[string]any
32+
consumes map[string]any
33+
}
34+
35+
func (m *mockPrepareRequestDataP) TypedName() plugins.TypedName {
36+
return plugins.TypedName{Name: m.name, Type: "mock"}
37+
}
38+
39+
func (m *mockPrepareRequestDataP) Produces() map[string]any {
40+
return m.produces
41+
}
42+
43+
func (m *mockPrepareRequestDataP) Consumes() map[string]any {
44+
return m.consumes
45+
}
46+
47+
func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
48+
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
49+
return nil
50+
}
51+
52+
func TestPrepareDataGraph(t *testing.T) {
53+
pluginA := &mockPrepareRequestDataP{name: "A", produces: map[string]any{"keyA": nil}}
54+
pluginB := &mockPrepareRequestDataP{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}}
55+
pluginC := &mockPrepareRequestDataP{name: "C", consumes: map[string]any{"keyB": nil}}
56+
pluginD := &mockPrepareRequestDataP{name: "D", consumes: map[string]any{"keyA": nil}}
57+
pluginE := &mockPrepareRequestDataP{name: "E"} // No dependencies
58+
59+
// Cycle plugins
60+
pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
61+
pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}
62+
63+
testCases := []struct {
64+
name string
65+
plugins []PrepareDataPlugin
66+
expectedDAG map[string][]string
67+
expectError bool
68+
}{
69+
{
70+
name: "No plugins",
71+
plugins: []PrepareDataPlugin{},
72+
expectedDAG: map[string][]string{},
73+
expectError: false,
74+
},
75+
{
76+
name: "Plugins with no dependencies",
77+
plugins: []PrepareDataPlugin{pluginA, pluginE},
78+
expectedDAG: map[string][]string{
79+
"A/mock": {},
80+
"E/mock": {},
81+
},
82+
expectError: false,
83+
},
84+
{
85+
name: "Simple linear dependency (A -> B -> C)",
86+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
87+
expectedDAG: map[string][]string{
88+
"A/mock": {},
89+
"B/mock": {"A/mock"},
90+
"C/mock": {"B/mock"},
91+
},
92+
expectError: false,
93+
},
94+
{
95+
name: "DAG with multiple dependencies (A -> B, A -> D)",
96+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
97+
expectedDAG: map[string][]string{
98+
"A/mock": {},
99+
"B/mock": {"A/mock"},
100+
"D/mock": {"A/mock"},
101+
"E/mock": {},
102+
},
103+
expectError: false,
104+
},
105+
{
106+
name: "Graph with a cycle (X -> Y, Y -> X)",
107+
plugins: []PrepareDataPlugin{pluginX, pluginY},
108+
expectedDAG: nil,
109+
expectError: true,
110+
},
111+
{
112+
name: "Complex graph with a cycle",
113+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY},
114+
expectedDAG: nil,
115+
expectError: true,
116+
},
117+
}
118+
119+
for _, tc := range testCases {
120+
t.Run(tc.name, func(t *testing.T) {
121+
dag, err := prepareDataGraph(tc.plugins)
122+
123+
if tc.expectError {
124+
assert.Error(t, err)
125+
assert.Nil(t, dag)
126+
assert.Contains(t, err.Error(), "cycle detected")
127+
} else {
128+
assert.NoError(t, err)
129+
130+
// Normalize the slices in the maps for consistent comparison
131+
normalizedDAG := make(map[string][]string)
132+
for k, v := range dag {
133+
normalizedDAG[k] = v
134+
}
135+
normalizedExpectedDAG := make(map[string][]string)
136+
for k, v := range tc.expectedDAG {
137+
normalizedExpectedDAG[k] = v
138+
}
139+
140+
if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
141+
t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
142+
}
143+
}
144+
})
145+
}
146+
}

pkg/epp/requestcontrol/director.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
353353
}
354354
}
355355

356-
// TODO: Execute plugins in parallel once DAG execution is supported.
357-
// runPrepareDataPlugins executes PrepareDataPlugins sequentially.
358356
func (d *Director) runPrepareDataPlugins(ctx context.Context,
359357
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
360-
return prepareDataPluginsWithTimeout(
361-
prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
362-
358+
return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
363359
}
364360

365361
func (d *Director) runAdmissionPlugins(ctx context.Context,

pkg/epp/requestcontrol/plugin_executor.go

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,68 @@ package requestcontrol
1919
import (
2020
"context"
2121
"errors"
22+
"fmt"
2223
"time"
2324

2425
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2526
)
2627

28+
// executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously.
29+
// So, a plugin is executed only after all its dependencies have been executed.
30+
// If there is a cycle or other error in the DAG, it returns an error.
31+
func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
32+
// Build the DAG
33+
// The error validation happens on startup when loading the config. So, here there should not be any error.
34+
dag, err := prepareDataGraph(plugins)
35+
if err != nil {
36+
return err
37+
}
38+
// Execute the DAG
39+
40+
// Channels to signal plugin execution completion.
41+
pluginExecuted := make(map[string]chan error)
42+
nameToNode := map[string]PrepareDataPlugin{}
43+
for _, plugin := range plugins {
44+
pluginExecuted[plugin.TypedName().String()] = make(chan error)
45+
nameToNode[plugin.TypedName().String()] = plugin
46+
}
47+
48+
for pluginName, dependents := range dag {
49+
// Execute plugins based on dependencies.
50+
// Wait for the dependencies to complete before executing a plugin.
51+
go func() {
52+
for _, dep := range dependents {
53+
err, open := <-pluginExecuted[dep]
54+
if !open {
55+
continue
56+
}
57+
if err != nil {
58+
// If a dependency failed, propagate the error and do not execute this plugin.
59+
pluginExecuted[pluginName] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err)
60+
}
61+
}
62+
// Signal that the plugin has been executed.
63+
defer close(pluginExecuted[pluginName])
64+
65+
pluginExecuted[pluginName] <- nameToNode[pluginName].PrepareRequestData(ctx, request, pods)
66+
}()
67+
}
68+
for pluginName := range dag {
69+
err := <-pluginExecuted[pluginName]
70+
if err != nil {
71+
return errors.New("prepare data plugin " + pluginName + " failed: " + err.Error())
72+
}
73+
}
74+
return nil
75+
}
76+
2777
// prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout.
2878
func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin,
2979
ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
3080
errCh := make(chan error, 1)
3181
// Execute plugins sequentially in a separate goroutine
3282
go func() {
33-
for _, plugin := range plugins {
34-
err := plugin.PrepareRequestData(ctx, request, pods)
35-
if err != nil {
36-
errCh <- errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error())
37-
return
38-
}
39-
}
40-
errCh <- nil
83+
errCh <- executePluginsAsDAG(plugins, ctx, request, pods)
4184
}()
4285

4386
select {

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,12 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
104104
}
105105
}
106106
}
107+
108+
// ValidatePrepareDataPlugins validates the PrepareData plugins in the Config.
109+
// It builds the data dependency graph and checks for cycles.
110+
// If a cycle is detected, it returns an error.
111+
func (c *Config) ValidatePrepareDataPlugins() error {
112+
_, err := prepareDataGraph(c.prepareDataPlugins)
113+
114+
return err
115+
}

0 commit comments

Comments
 (0)