From 66bf69fc58b585849b68f4edd7691f4af6ba69cf Mon Sep 17 00:00:00 2001 From: Rahul Gurnani Date: Tue, 11 Nov 2025 20:09:10 +0000 Subject: [PATCH 1/5] Parallelize execution of prepare data plugins as a DAG. Also detect data dependency cycles on startup. --- cmd/epp/runner/runner.go | 4 + pkg/epp/requestcontrol/dag.go | 99 ++++++++++++ pkg/epp/requestcontrol/dag_test.go | 146 ++++++++++++++++++ pkg/epp/requestcontrol/director.go | 6 +- pkg/epp/requestcontrol/plugin_executor.go | 59 ++++++- .../requestcontrol/request_control_config.go | 9 ++ 6 files changed, 310 insertions(+), 13 deletions(-) create mode 100644 pkg/epp/requestcontrol/dag.go create mode 100644 pkg/epp/requestcontrol/dag_test.go diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index e6caafa2d..55a6c37b1 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -459,6 +459,10 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf // Add requestControl plugins r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) + // Check prepare data plugins for cycles. + if r.requestControlConfig.ValidatePrepareDataPlugins() != nil { + return errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") + } // Handler deprecated configuration options r.deprecatedConfigurationHelper(cfg, logger) diff --git a/pkg/epp/requestcontrol/dag.go b/pkg/epp/requestcontrol/dag.go new file mode 100644 index 000000000..891e36e85 --- /dev/null +++ b/pkg/epp/requestcontrol/dag.go @@ -0,0 +1,99 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +import "errors" + +// buildDAG builds a dependency graph among data preparation plugins based on their +// produced and consumed data keys. +func buildDAG(plugins []PrepareDataPlugin) map[string][]string { + dag := make(map[string][]string) + for _, plugin := range plugins { + dag[plugin.TypedName().String()] = []string{} + } + // Create dependency graph as a DAG. + for i := range plugins { + for j := range plugins { + if i == j { + continue + } + // Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i. + if plugins[i].Produces() != nil && plugins[j].Consumes() != nil { + // For all the keys produced by plugin i, check if plugin j consumes any of them. + // If yes, then j depends on i. + for producedKey := range plugins[i].Produces() { + // If plugin j consumes the produced key, then j depends on i. We can break after the first match. + if _, ok := plugins[j].Consumes()[producedKey]; ok { + iPluginName := plugins[i].TypedName().String() + jPluginName := plugins[j].TypedName().String() + dag[jPluginName] = append(dag[jPluginName], iPluginName) + break + } + } + } + } + } + return dag +} + +// prepareDataGraph builds a DAG of data preparation plugins and checks for cycles. +// If there is a cycle, it returns an error. +func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) { + dag := buildDAG(plugins) + + // Check for cycles in the DAG. 
+ if cycleExistsInDAG(dag) { + return nil, errors.New("cycle detected in data preparation plugin dependencies") + } + + return dag, nil +} + +// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list. +func cycleExistsInDAG(dag map[string][]string) bool { + visited := make(map[string]bool) + recStack := make(map[string]bool) + + var dfs func(string) bool + dfs = func(node string) bool { + if recStack[node] { + return true // Cycle detected + } + if visited[node] { + return false + } + visited[node] = true + recStack[node] = true + + for _, neighbor := range dag[node] { + if dfs(neighbor) { + return true + } + } + recStack[node] = false + return false + } + + for pluginName := range dag { + if !visited[pluginName] { + if dfs(pluginName) { + return true + } + } + } + return false +} diff --git a/pkg/epp/requestcontrol/dag_test.go b/pkg/epp/requestcontrol/dag_test.go new file mode 100644 index 000000000..a62bd3793 --- /dev/null +++ b/pkg/epp/requestcontrol/dag_test.go @@ -0,0 +1,146 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package requestcontrol + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +type mockPrepareRequestDataP struct { + name string + produces map[string]any + consumes map[string]any +} + +func (m *mockPrepareRequestDataP) TypedName() plugins.TypedName { + return plugins.TypedName{Name: m.name, Type: "mock"} +} + +func (m *mockPrepareRequestDataP) Produces() map[string]any { + return m.produces +} + +func (m *mockPrepareRequestDataP) Consumes() map[string]any { + return m.consumes +} + +func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error { + pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) + return nil +} + +func TestPrepareDataGraph(t *testing.T) { + pluginA := &mockPrepareRequestDataP{name: "A", produces: map[string]any{"keyA": nil}} + pluginB := &mockPrepareRequestDataP{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}} + pluginC := &mockPrepareRequestDataP{name: "C", consumes: map[string]any{"keyB": nil}} + pluginD := &mockPrepareRequestDataP{name: "D", consumes: map[string]any{"keyA": nil}} + pluginE := &mockPrepareRequestDataP{name: "E"} // No dependencies + + // Cycle plugins + pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}} + pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}} + + testCases := []struct { + name string + plugins []PrepareDataPlugin + expectedDAG map[string][]string + expectError bool + }{ + { + name: "No plugins", + plugins: []PrepareDataPlugin{}, + expectedDAG: map[string][]string{}, + expectError: false, + }, + { + name: "Plugins with no dependencies", + plugins: 
[]PrepareDataPlugin{pluginA, pluginE}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "E/mock": {}, + }, + expectError: false, + }, + { + name: "Simple linear dependency (A -> B -> C)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "B/mock": {"A/mock"}, + "C/mock": {"B/mock"}, + }, + expectError: false, + }, + { + name: "DAG with multiple dependencies (A -> B, A -> D)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "B/mock": {"A/mock"}, + "D/mock": {"A/mock"}, + "E/mock": {}, + }, + expectError: false, + }, + { + name: "Graph with a cycle (X -> Y, Y -> X)", + plugins: []PrepareDataPlugin{pluginX, pluginY}, + expectedDAG: nil, + expectError: true, + }, + { + name: "Complex graph with a cycle", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY}, + expectedDAG: nil, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dag, err := prepareDataGraph(tc.plugins) + + if tc.expectError { + assert.Error(t, err) + assert.Nil(t, dag) + assert.Contains(t, err.Error(), "cycle detected") + } else { + assert.NoError(t, err) + + // Normalize the slices in the maps for consistent comparison + normalizedDAG := make(map[string][]string) + for k, v := range dag { + normalizedDAG[k] = v + } + normalizedExpectedDAG := make(map[string][]string) + for k, v := range tc.expectedDAG { + normalizedExpectedDAG[k] = v + } + + if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" { + t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index e8db59b7d..4c35627b7 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -353,13 +353,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling } } 
-// TODO: Execute plugins in parallel once DAG execution is supported. -// runPrepareDataPlugins executes PrepareDataPlugins sequentially. func (d *Director) runPrepareDataPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { - return prepareDataPluginsWithTimeout( - prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) - + return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) } func (d *Director) runAdmissionPlugins(ctx context.Context, diff --git a/pkg/epp/requestcontrol/plugin_executor.go b/pkg/epp/requestcontrol/plugin_executor.go index be8d4ba39..1fbdc5c97 100644 --- a/pkg/epp/requestcontrol/plugin_executor.go +++ b/pkg/epp/requestcontrol/plugin_executor.go @@ -19,25 +19,68 @@ package requestcontrol import ( "context" "errors" + "fmt" "time" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +// executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously. +// So, a plugin is executed only after all its dependencies have been executed. +// If there is a cycle or other error in the DAG, it returns an error. +func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + // Build the DAG + // The error validation happens on startup when loading the config. So, here there should not be any error. + dag, err := prepareDataGraph(plugins) + if err != nil { + return err + } + // Execute the DAG + + // Channels to signal plugin execution completion. 
+ pluginExecuted := make(map[string]chan error) + nameToNode := map[string]PrepareDataPlugin{} + for _, plugin := range plugins { + pluginExecuted[plugin.TypedName().String()] = make(chan error) + nameToNode[plugin.TypedName().String()] = plugin + } + + for pluginName, dependents := range dag { + // Execute plugins based on dependencies. + // Wait for the dependencies to complete before executing a plugin. + go func() { + for _, dep := range dependents { + err, open := <-pluginExecuted[dep] + if !open { + continue + } + if err != nil { + // If a dependency failed, propagate the error and do not execute this plugin. + pluginExecuted[pluginName] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err) + } + } + // Signal that the plugin has been executed. + defer close(pluginExecuted[pluginName]) + + pluginExecuted[pluginName] <- nameToNode[pluginName].PrepareRequestData(ctx, request, pods) + }() + } + for pluginName := range dag { + err := <-pluginExecuted[pluginName] + if err != nil { + return errors.New("prepare data plugin " + pluginName + " failed: " + err.Error()) + } + } + return nil +} + // prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout. 
func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { errCh := make(chan error, 1) // Execute plugins sequentially in a separate goroutine go func() { - for _, plugin := range plugins { - err := plugin.PrepareRequestData(ctx, request, pods) - if err != nil { - errCh <- errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error()) - return - } - } - errCh <- nil + errCh <- executePluginsAsDAG(plugins, ctx, request, pods) }() select { diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index 9701be999..f29fe5582 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -104,3 +104,12 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { } } } + +// ValidatePrepareDataPlugins validates the PrepareData plugins in the Config. +// It builds the data dependency graph and checks for cycles. +// If a cycle is detected, it returns an error. +func (c *Config) ValidatePrepareDataPlugins() error { + _, err := prepareDataGraph(c.prepareDataPlugins) + + return err +} From 02dfa484516381d1bde27e469b7bd0f70bed1792 Mon Sep 17 00:00:00 2001 From: Rahul Gurnani Date: Tue, 18 Nov 2025 22:02:43 +0000 Subject: [PATCH 2/5] Use buffered channel based approach to synchronize go routines. 
Also add unit tests --- pkg/epp/requestcontrol/plugin_executor.go | 60 +++++--- .../requestcontrol/plugin_executor_test.go | 136 ++++++++++++++++++ 2 files changed, 174 insertions(+), 22 deletions(-) diff --git a/pkg/epp/requestcontrol/plugin_executor.go b/pkg/epp/requestcontrol/plugin_executor.go index 1fbdc5c97..1048be497 100644 --- a/pkg/epp/requestcontrol/plugin_executor.go +++ b/pkg/epp/requestcontrol/plugin_executor.go @@ -27,7 +27,7 @@ import ( // executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously. // So, a plugin is executed only after all its dependencies have been executed. -// If there is a cycle or other error in the DAG, it returns an error. +// If there is a cycle or any plugin fails with error, it returns an error. func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { // Build the DAG // The error validation happens on startup when loading the config. So, here there should not be any error. @@ -35,40 +35,57 @@ func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, reque if err != nil { return err } - // Execute the DAG - - // Channels to signal plugin execution completion. - pluginExecuted := make(map[string]chan error) + // Create a readonly map of plugin name to plugin instance. nameToNode := map[string]PrepareDataPlugin{} for _, plugin := range plugins { - pluginExecuted[plugin.TypedName().String()] = make(chan error) nameToNode[plugin.TypedName().String()] = plugin } + // Execute the DAG + // Channels to signal plugin execution completion. + pluginExecuted := make(map[string]chan error) + // The capacity of the channel is equal to the number of dependents + 1 (for itself). 
+ capacityMap := make(map[string]int) + for pluginName := range dag { + capacityMap[pluginName]++ + for _, dep := range dag[pluginName] { + capacityMap[dep]++ + } + } + for pluginName, capacity := range capacityMap { + pluginExecuted[pluginName] = make(chan error, capacity) + } for pluginName, dependents := range dag { // Execute plugins based on dependencies. // Wait for the dependencies to complete before executing a plugin. - go func() { - for _, dep := range dependents { - err, open := <-pluginExecuted[dep] - if !open { - continue - } - if err != nil { - // If a dependency failed, propagate the error and do not execute this plugin. - pluginExecuted[pluginName] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err) + go func(pN string, depds []string) { + for _, dep := range depds { + // Wait for the dependency plugin to signal completion. + if err, closed := <-pluginExecuted[dep]; closed { + if err != nil { + for range cap(pluginExecuted[pN]) { + // Notify all dependents about the failure. + pluginExecuted[pN] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err) + } + // Do not execute this plugin as one of its dependencies failed. + return + } } } - // Signal that the plugin has been executed. - defer close(pluginExecuted[pluginName]) - - pluginExecuted[pluginName] <- nameToNode[pluginName].PrepareRequestData(ctx, request, pods) - }() + res := nameToNode[pN].PrepareRequestData(ctx, request, pods) + for range cap(pluginExecuted[pN]) { + // Notify all dependents about the completion. + pluginExecuted[pN] <- res + } + }(pluginName, dependents) } + + // Check for errors in plugin execution. + // This will also ensure that all plugins have completed execution before returning. 
for pluginName := range dag { err := <-pluginExecuted[pluginName] if err != nil { - return errors.New("prepare data plugin " + pluginName + " failed: " + err.Error()) + return fmt.Errorf("prepare data plugin %s failed: %v", pluginName, err) } } return nil @@ -78,7 +95,6 @@ func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, reque func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { errCh := make(chan error, 1) - // Execute plugins sequentially in a separate goroutine go func() { errCh <- executePluginsAsDAG(plugins, ctx, request, pods) }() diff --git a/pkg/epp/requestcontrol/plugin_executor_test.go b/pkg/epp/requestcontrol/plugin_executor_test.go index b37264f6b..1825f7a4c 100644 --- a/pkg/epp/requestcontrol/plugin_executor_test.go +++ b/pkg/epp/requestcontrol/plugin_executor_test.go @@ -19,6 +19,7 @@ package requestcontrol import ( "context" "errors" + "sync" "testing" "time" @@ -157,3 +158,138 @@ func TestPrepareDataPluginsWithTimeout(t *testing.T) { }) } } + +type dagTestPlugin struct { + mockPrepareRequestDataPlugin + produces map[string]any + consumes map[string]any + execTime time.Time + mu sync.Mutex +} + +func (p *dagTestPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + p.mu.Lock() + defer p.mu.Unlock() + p.execTime = time.Now() + return p.mockPrepareRequestDataPlugin.PrepareRequestData(ctx, request, pods) +} + +func (p *dagTestPlugin) Produces() map[string]any { + return p.produces +} + +func (p *dagTestPlugin) Consumes() map[string]any { + return p.consumes +} + +func TestExecutePluginsAsDAG(t *testing.T) { + pluginA := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "A", delay: 20 * time.Millisecond}, + produces: map[string]any{"keyA": nil}, + } + pluginB := &dagTestPlugin{ + 
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "B"}, + consumes: map[string]any{"keyA": nil}, + produces: map[string]any{"keyB": nil}, + } + pluginC := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "C"}, + consumes: map[string]any{"keyB": nil}, + } + pluginD := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "D"}, + consumes: map[string]any{"keyA": nil}, + } + pluginE := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "E"}, + } + pluginFail := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "Fail", returnErr: errors.New("plugin failed")}, + produces: map[string]any{"keyFail": nil}, + } + pluginDependsOnFail := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "DependsOnFail"}, + consumes: map[string]any{"keyFail": nil}, + } + + testCases := []struct { + name string + plugins []PrepareDataPlugin + expectErr bool + checkFunc func(t *testing.T, plugins []PrepareDataPlugin) + }{ + { + name: "no plugins", + plugins: []PrepareDataPlugin{}, + }, + { + name: "simple linear dependency (A -> B -> C)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC}, + checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) { + pA := plugins[0].(*dagTestPlugin) + pB := plugins[1].(*dagTestPlugin) + pC := plugins[2].(*dagTestPlugin) + + assert.True(t, pA.executed, "Plugin A should have been executed") + assert.True(t, pB.executed, "Plugin B should have been executed") + assert.True(t, pC.executed, "Plugin C should have been executed") + + assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A") + assert.True(t, pC.execTime.After(pB.execTime), "Plugin C should execute after B") + }, + }, + { + name: "DAG with multiple dependencies (A -> B, A -> D) and one independent (E)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE}, + checkFunc: func(t 
*testing.T, plugins []PrepareDataPlugin) { + pA := plugins[0].(*dagTestPlugin) + pB := plugins[1].(*dagTestPlugin) + pD := plugins[2].(*dagTestPlugin) + pE := plugins[3].(*dagTestPlugin) + + assert.True(t, pA.executed, "Plugin A should have been executed") + assert.True(t, pB.executed, "Plugin B should have been executed") + assert.True(t, pD.executed, "Plugin D should have been executed") + assert.True(t, pE.executed, "Plugin E should have been executed") + + assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A") + assert.True(t, pD.execTime.After(pA.execTime), "Plugin D should execute after A") + }, + }, + { + name: "dependency fails", + plugins: []PrepareDataPlugin{pluginFail, pluginDependsOnFail}, + expectErr: true, + checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) { + pF := plugins[0].(*dagTestPlugin) + pDOF := plugins[1].(*dagTestPlugin) + + assert.True(t, pF.executed, "Failing plugin should have been executed") + assert.False(t, pDOF.executed, "Plugin depending on fail should not be executed") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset execution state for plugins + for _, p := range tc.plugins { + plugin := p.(*dagTestPlugin) + plugin.executed = false + plugin.execTime = time.Time{} + } + + err := executePluginsAsDAG(tc.plugins, context.Background(), &schedulingtypes.LLMRequest{}, nil) + + if tc.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if tc.checkFunc != nil { + tc.checkFunc(t, tc.plugins) + } + }) + } +} From d89c4cd2809eb078b4f0b944fb825fd932246835 Mon Sep 17 00:00:00 2001 From: Rahul Gurnani Date: Tue, 18 Nov 2025 23:06:09 +0000 Subject: [PATCH 3/5] Fix runner after rebase --- cmd/epp/runner/runner.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 55a6c37b1..eeaa6d34f 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -461,7 
+461,7 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) // Check prepare data plugins for cycles. if r.requestControlConfig.ValidatePrepareDataPlugins() != nil { - return errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") + return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") } // Handler deprecated configuration options From d225c3cc9f40372ca2cc632e97ffa612d5cd2b83 Mon Sep 17 00:00:00 2001 From: Rahul Gurnani Date: Wed, 19 Nov 2025 21:20:41 +0000 Subject: [PATCH 4/5] Cache DAG to avoid recomputation --- cmd/epp/runner/runner.go | 2 +- .../requestcontrol/request_control_config.go | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index eeaa6d34f..adfbe6851 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -460,7 +460,7 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf // Add requestControl plugins r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) // Check prepare data plugins for cycles. 
- if r.requestControlConfig.ValidatePrepareDataPlugins() != nil { + if _, err := r.requestControlConfig.PrepareDataPluginGraph(); err != nil { return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") } diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index f29fe5582..c84080ab0 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -40,6 +40,7 @@ type Config struct { responseReceivedPlugins []ResponseReceived responseStreamingPlugins []ResponseStreaming responseCompletePlugins []ResponseComplete + prepareDataPluginGraph map[string][]string } // WithPreRequestPlugins sets the given plugins as the PreRequest plugins. @@ -105,11 +106,16 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { } } -// ValidatePrepareDataPlugins validates the PrepareData plugins in the Config. -// It builds the data dependency graph and checks for cycles. +// PrepareDataPluginGraph creates and returns the data dependency graph of PrepareData plugins. // If a cycle is detected, it returns an error. 
-func (c *Config) ValidatePrepareDataPlugins() error { - _, err := prepareDataGraph(c.prepareDataPlugins) - - return err +func (c *Config) PrepareDataPluginGraph() (map[string][]string, error) { + if c.prepareDataPluginGraph != nil { + return c.prepareDataPluginGraph, nil + } + graph, err := prepareDataGraph(c.prepareDataPlugins) + if err != nil { + return nil, err + } + c.prepareDataPluginGraph = graph + return graph, nil } From db79ebf096fb5ac78a5c9ad9be6ac47338a4aef2 Mon Sep 17 00:00:00 2001 From: Rahul Gurnani Date: Thu, 20 Nov 2025 22:09:23 +0000 Subject: [PATCH 5/5] Make plugin execution sequential in topologically sorted order --- cmd/epp/runner/runner.go | 4 +- pkg/epp/requestcontrol/dag.go | 91 ++++++++++++------- pkg/epp/requestcontrol/dag_test.go | 37 +++++--- pkg/epp/requestcontrol/plugin_executor.go | 59 +----------- .../requestcontrol/request_control_config.go | 17 ++-- 5 files changed, 96 insertions(+), 112 deletions(-) diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index adfbe6851..8a35f0196 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -459,8 +459,8 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf // Add requestControl plugins r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) - // Check prepare data plugins for cycles. - if _, err := r.requestControlConfig.PrepareDataPluginGraph(); err != nil { + // Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles. + if r.requestControlConfig.PrepareDataPluginGraph() != nil { return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") } diff --git a/pkg/epp/requestcontrol/dag.go b/pkg/epp/requestcontrol/dag.go index 891e36e85..a0fdd1c5f 100644 --- a/pkg/epp/requestcontrol/dag.go +++ b/pkg/epp/requestcontrol/dag.go @@ -16,7 +16,10 @@ limitations under the License. 
package requestcontrol -import "errors" +import ( + "errors" + "slices" +) // buildDAG builds a dependency graph among data preparation plugins based on their // produced and consumed data keys. @@ -50,50 +53,76 @@ func buildDAG(plugins []PrepareDataPlugin) map[string][]string { return dag } -// prepareDataGraph builds a DAG of data preparation plugins and checks for cycles. +// prepareDataGraph builds the dependency graph and returns the plugins ordered in topological order. // If there is a cycle, it returns an error. -func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) { +func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, []PrepareDataPlugin, error) { dag := buildDAG(plugins) - - // Check for cycles in the DAG. - if cycleExistsInDAG(dag) { - return nil, errors.New("cycle detected in data preparation plugin dependencies") + nameToPlugin := map[string]PrepareDataPlugin{} + for _, plugin := range plugins { + nameToPlugin[plugin.TypedName().String()] = plugin + } + sortedPlugins, err := topologicalSort(dag) + if err != nil { + return nil, nil, err + } + orderedPlugins := []PrepareDataPlugin{} + for _, pluginName := range sortedPlugins { + orderedPlugins = append(orderedPlugins, nameToPlugin[pluginName]) } - return dag, nil + return dag, orderedPlugins, err } -// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list. -func cycleExistsInDAG(dag map[string][]string) bool { - visited := make(map[string]bool) - recStack := make(map[string]bool) +// topologicalSort performs Kahn's Algorithm on a DAG. +// It returns the sorted order or an error if a cycle is detected. +func topologicalSort(graph map[string][]string) ([]string, error) { + // 1.
Initialize in-degree map + inDegree := make(map[string]int) - var dfs func(string) bool - dfs = func(node string) bool { - if recStack[node] { - return true // Cycle detected + // Ensure all nodes are present in the inDegree map, even those with no incoming edges + for u, neighbors := range graph { + if _, ok := inDegree[u]; !ok { + inDegree[u] = 0 } - if visited[node] { - return false + for _, v := range neighbors { + inDegree[v]++ // Increment in-degree for the destination node } - visited[node] = true - recStack[node] = true + } - for _, neighbor := range dag[node] { - if dfs(neighbor) { - return true - } + // 2. Initialize the queue with nodes having 0 in-degree + var queue []string + for node, degree := range inDegree { + if degree == 0 { + queue = append(queue, node) } - recStack[node] = false - return false } - for pluginName := range dag { - if !visited[pluginName] { - if dfs(pluginName) { - return true + var result []string + + // 3. Process the queue + for len(queue) > 0 { + // Dequeue + u := queue[0] + queue = queue[1:] + + result = append(result, u) + + // Decrease in-degree of neighbors + if neighbors, ok := graph[u]; ok { + for _, v := range neighbors { + inDegree[v]-- + if inDegree[v] == 0 { + queue = append(queue, v) + } } } } - return false + + // 4.
Check for cycles + // If the result size != total nodes, there is a cycle + if len(result) != len(inDegree) { + return nil, errors.New("cycle detected: graph is not a DAG") + } + slices.Reverse(result) + return result, nil } diff --git a/pkg/epp/requestcontrol/dag_test.go b/pkg/epp/requestcontrol/dag_test.go index a62bd3793..c38df188d 100644 --- a/pkg/epp/requestcontrol/dag_test.go +++ b/pkg/epp/requestcontrol/dag_test.go @@ -18,6 +18,7 @@ package requestcontrol import ( "context" + "maps" "testing" "github.com/google/go-cmp/cmp" @@ -82,7 +83,7 @@ func TestPrepareDataGraph(t *testing.T) { expectError: false, }, { - name: "Simple linear dependency (A -> B -> C)", + name: "Simple linear dependency (C -> B -> A)", plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC}, expectedDAG: map[string][]string{ "A/mock": {}, @@ -92,7 +93,7 @@ func TestPrepareDataGraph(t *testing.T) { expectError: false, }, { - name: "DAG with multiple dependencies (A -> B, A -> D)", + name: "DAG with multiple dependencies (B -> A, D -> A, E independent)", plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE}, expectedDAG: map[string][]string{ "A/mock": {}, @@ -108,17 +109,11 @@ func TestPrepareDataGraph(t *testing.T) { expectedDAG: nil, expectError: true, }, - { - name: "Complex graph with a cycle", - plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY}, - expectedDAG: nil, - expectError: true, - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dag, err := prepareDataGraph(tc.plugins) + dag, orderedPlugins, err := prepareDataGraph(tc.plugins) if tc.expectError { assert.Error(t, err) @@ -129,9 +124,7 @@ func TestPrepareDataGraph(t *testing.T) { // Normalize the slices in the maps for consistent comparison normalizedDAG := make(map[string][]string) - for k, v := range dag { - normalizedDAG[k] = v - } + maps.Copy(normalizedDAG, dag) normalizedExpectedDAG := make(map[string][]string) for k, v := range tc.expectedDAG { normalizedExpectedDAG[k] = v @@ 
-140,7 +133,27 @@ func TestPrepareDataGraph(t *testing.T) { if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" { t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff) } + + orderedPluginNames := make([]string, len(orderedPlugins)) + for i, p := range orderedPlugins { + orderedPluginNames[i] = p.TypedName().String() + } + assertTopologicalOrder(t, dag, orderedPlugins) } }) } } + +func assertTopologicalOrder(t *testing.T, dag map[string][]string, ordered []PrepareDataPlugin) { + t.Helper() + positions := make(map[string]int) + for i, p := range ordered { + positions[p.TypedName().String()] = i + } + + for node, dependencies := range dag { + for _, dep := range dependencies { + assert.Less(t, positions[dep], positions[node], "Dependency %s should come before %s", dep, node) + } + } +} diff --git a/pkg/epp/requestcontrol/plugin_executor.go b/pkg/epp/requestcontrol/plugin_executor.go index 1048be497..7771c4f67 100644 --- a/pkg/epp/requestcontrol/plugin_executor.go +++ b/pkg/epp/requestcontrol/plugin_executor.go @@ -19,7 +19,6 @@ package requestcontrol import ( "context" "errors" - "fmt" "time" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -29,63 +28,9 @@ import ( // So, a plugin is executed only after all its dependencies have been executed. // If there is a cycle or any plugin fails with error, it returns an error. func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { - // Build the DAG - // The error validation happens on startup when loading the config. So, here there should not be any error. - dag, err := prepareDataGraph(plugins) - if err != nil { - return err - } - // Create a readonly map of plugin name to plugin instance. 
- nameToNode := map[string]PrepareDataPlugin{} for _, plugin := range plugins { - nameToNode[plugin.TypedName().String()] = plugin - } - // Execute the DAG - - // Channels to signal plugin execution completion. - pluginExecuted := make(map[string]chan error) - // The capacity of the channel is equal to the number of dependents + 1 (for itself). - capacityMap := make(map[string]int) - for pluginName := range dag { - capacityMap[pluginName]++ - for _, dep := range dag[pluginName] { - capacityMap[dep]++ - } - } - for pluginName, capacity := range capacityMap { - pluginExecuted[pluginName] = make(chan error, capacity) - } - for pluginName, dependents := range dag { - // Execute plugins based on dependencies. - // Wait for the dependencies to complete before executing a plugin. - go func(pN string, depds []string) { - for _, dep := range depds { - // Wait for the dependency plugin to signal completion. - if err, closed := <-pluginExecuted[dep]; closed { - if err != nil { - for range cap(pluginExecuted[pN]) { - // Notify all dependents about the failure. - pluginExecuted[pN] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err) - } - // Do not execute this plugin as one of its dependencies failed. - return - } - } - } - res := nameToNode[pN].PrepareRequestData(ctx, request, pods) - for range cap(pluginExecuted[pN]) { - // Notify all dependents about the completion. - pluginExecuted[pN] <- res - } - }(pluginName, dependents) - } - - // Check for errors in plugin execution. - // This will also ensure that all plugins have completed execution before returning. 
- for pluginName := range dag { - err := <-pluginExecuted[pluginName] - if err != nil { - return fmt.Errorf("prepare data plugin %s failed: %v", pluginName, err) + if err := plugin.PrepareRequestData(ctx, request, pods); err != nil { + return errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error()) } } return nil diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go index c84080ab0..1b00480bd 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -40,7 +40,6 @@ type Config struct { responseReceivedPlugins []ResponseReceived responseStreamingPlugins []ResponseStreaming responseCompletePlugins []ResponseComplete - prepareDataPluginGraph map[string][]string } // WithPreRequestPlugins sets the given plugins as the PreRequest plugins. @@ -106,16 +105,14 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { } } -// PrepareDataPluginGraph creates and returns the data dependency graph of PrepareData plugins. +// PrepareDataPluginGraph builds the data dependency graph of the PrepareData plugins and replaces the Config's plugin list with the topologically sorted one. // If a cycle is detected, it returns an error. -func (c *Config) PrepareDataPluginGraph() (map[string][]string, error) { - if c.prepareDataPluginGraph != nil { - return c.prepareDataPluginGraph, nil - } - graph, err := prepareDataGraph(c.prepareDataPlugins) +func (c *Config) PrepareDataPluginGraph() error { + _, plugins, err := prepareDataGraph(c.prepareDataPlugins) if err != nil { - return nil, err + return err } - c.prepareDataPlugins = graph - return graph, nil + c.prepareDataPlugins = plugins + + return nil }