diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index e6caafa2d..8a35f0196 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -459,6 +459,10 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf // Add requestControl plugins r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...) + // Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles. + if r.requestControlConfig.PrepareDataPluginGraph() != nil { + return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies") + } // Handler deprecated configuration options r.deprecatedConfigurationHelper(cfg, logger) diff --git a/pkg/epp/requestcontrol/dag.go b/pkg/epp/requestcontrol/dag.go new file mode 100644 index 000000000..a0fdd1c5f --- /dev/null +++ b/pkg/epp/requestcontrol/dag.go @@ -0,0 +1,128 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +import ( + "errors" + "slices" +) + +// buildDAG builds a dependency graph among data preparation plugins based on their +// produced and consumed data keys. +func buildDAG(plugins []PrepareDataPlugin) map[string][]string { + dag := make(map[string][]string) + for _, plugin := range plugins { + dag[plugin.TypedName().String()] = []string{} + } + // Create dependency graph as a DAG. 
// topologicalSort orders the nodes of a dependency graph using Kahn's algorithm.
//
// graph maps every node to the list of nodes it depends on. The returned slice
// contains each node exactly once, with dependencies placed before the nodes that
// depend on them. For a given graph the result is deterministic: the initial set
// of ready nodes is sorted, so the order does not vary with Go's randomized map
// iteration. An empty graph yields a nil slice and a nil error.
//
// An error is returned when the graph contains a cycle.
func topologicalSort(graph map[string][]string) ([]string, error) {
	if len(graph) == 0 {
		return nil, nil
	}

	// inDegree[v] is the number of nodes that list v as a dependency. Every
	// node gets an entry, including those nothing depends on.
	inDegree := make(map[string]int, len(graph))
	for node, deps := range graph {
		if _, ok := inDegree[node]; !ok {
			inDegree[node] = 0
		}
		for _, dep := range deps {
			inDegree[dep]++
		}
	}

	// Seed the queue with the nodes nothing depends on. Sorting removes the
	// only source of nondeterminism (map iteration order).
	queue := make([]string, 0, len(inDegree))
	for node, degree := range inDegree {
		if degree == 0 {
			queue = append(queue, node)
		}
	}
	slices.Sort(queue)

	result := make([]string, 0, len(inDegree))
	for len(queue) > 0 {
		node := queue[0]
		queue = queue[1:]
		result = append(result, node)

		// Releasing node may make its dependencies ready.
		for _, dep := range graph[node] {
			inDegree[dep]--
			if inDegree[dep] == 0 {
				queue = append(queue, dep)
			}
		}
	}

	// Nodes that are part of a cycle never reach in-degree zero.
	if len(result) != len(inDegree) {
		return nil, errors.New("cycle detected: graph is not a DAG")
	}

	// Dependents were emitted first; reverse so dependencies come first.
	slices.Reverse(result)
	return result, nil
}
+*/ + +package requestcontrol + +import ( + "context" + "maps" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +type mockPrepareRequestDataP struct { + name string + produces map[string]any + consumes map[string]any +} + +func (m *mockPrepareRequestDataP) TypedName() plugins.TypedName { + return plugins.TypedName{Name: m.name, Type: "mock"} +} + +func (m *mockPrepareRequestDataP) Produces() map[string]any { + return m.produces +} + +func (m *mockPrepareRequestDataP) Consumes() map[string]any { + return m.consumes +} + +func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error { + pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) + return nil +} + +func TestPrepareDataGraph(t *testing.T) { + pluginA := &mockPrepareRequestDataP{name: "A", produces: map[string]any{"keyA": nil}} + pluginB := &mockPrepareRequestDataP{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}} + pluginC := &mockPrepareRequestDataP{name: "C", consumes: map[string]any{"keyB": nil}} + pluginD := &mockPrepareRequestDataP{name: "D", consumes: map[string]any{"keyA": nil}} + pluginE := &mockPrepareRequestDataP{name: "E"} // No dependencies + + // Cycle plugins + pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}} + pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}} + + testCases := []struct { + name string + plugins []PrepareDataPlugin + expectedDAG map[string][]string + expectError bool + }{ + { + name: "No plugins", + plugins: []PrepareDataPlugin{}, + expectedDAG: map[string][]string{}, + expectError: false, + }, + { + name: "Plugins with no dependencies", + 
plugins: []PrepareDataPlugin{pluginA, pluginE}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "E/mock": {}, + }, + expectError: false, + }, + { + name: "Simple linear dependency (C -> B -> A)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "B/mock": {"A/mock"}, + "C/mock": {"B/mock"}, + }, + expectError: false, + }, + { + name: "DAG with multiple dependencies (B -> A, D -> A, E independent)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE}, + expectedDAG: map[string][]string{ + "A/mock": {}, + "B/mock": {"A/mock"}, + "D/mock": {"A/mock"}, + "E/mock": {}, + }, + expectError: false, + }, + { + name: "Graph with a cycle (X -> Y, Y -> X)", + plugins: []PrepareDataPlugin{pluginX, pluginY}, + expectedDAG: nil, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dag, orderedPlugins, err := prepareDataGraph(tc.plugins) + + if tc.expectError { + assert.Error(t, err) + assert.Nil(t, dag) + assert.Contains(t, err.Error(), "cycle detected") + } else { + assert.NoError(t, err) + + // Normalize the slices in the maps for consistent comparison + normalizedDAG := make(map[string][]string) + maps.Copy(normalizedDAG, dag) + normalizedExpectedDAG := make(map[string][]string) + for k, v := range tc.expectedDAG { + normalizedExpectedDAG[k] = v + } + + if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" { + t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff) + } + + orderedPluginNames := make([]string, len(orderedPlugins)) + for i, p := range orderedPlugins { + orderedPluginNames[i] = p.TypedName().String() + } + assertTopologicalOrder(t, dag, orderedPlugins) + } + }) + } +} + +func assertTopologicalOrder(t *testing.T, dag map[string][]string, ordered []PrepareDataPlugin) { + t.Helper() + positions := make(map[string]int) + for i, p := range ordered { + positions[p.TypedName().String()] = i + } + + for node, 
dependencies := range dag { + for _, dep := range dependencies { + assert.Less(t, positions[dep], positions[node], "Dependency %s should come before %s", dep, node) + } + } +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index e8db59b7d..4c35627b7 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -353,13 +353,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling } } -// TODO: Execute plugins in parallel once DAG execution is supported. -// runPrepareDataPlugins executes PrepareDataPlugins sequentially. func (d *Director) runPrepareDataPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { - return prepareDataPluginsWithTimeout( - prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) - + return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) } func (d *Director) runAdmissionPlugins(ctx context.Context, diff --git a/pkg/epp/requestcontrol/plugin_executor.go b/pkg/epp/requestcontrol/plugin_executor.go index be8d4ba39..7771c4f67 100644 --- a/pkg/epp/requestcontrol/plugin_executor.go +++ b/pkg/epp/requestcontrol/plugin_executor.go @@ -24,20 +24,24 @@ import ( schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +// executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously. +// So, a plugin is executed only after all its dependencies have been executed. +// If there is a cycle or any plugin fails with error, it returns an error. 
+func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + for _, plugin := range plugins { + if err := plugin.PrepareRequestData(ctx, request, pods); err != nil { + return errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error()) + } + } + return nil +} + // prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout. func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { errCh := make(chan error, 1) - // Execute plugins sequentially in a separate goroutine go func() { - for _, plugin := range plugins { - err := plugin.PrepareRequestData(ctx, request, pods) - if err != nil { - errCh <- errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error()) - return - } - } - errCh <- nil + errCh <- executePluginsAsDAG(plugins, ctx, request, pods) }() select { diff --git a/pkg/epp/requestcontrol/plugin_executor_test.go b/pkg/epp/requestcontrol/plugin_executor_test.go index b37264f6b..1825f7a4c 100644 --- a/pkg/epp/requestcontrol/plugin_executor_test.go +++ b/pkg/epp/requestcontrol/plugin_executor_test.go @@ -19,6 +19,7 @@ package requestcontrol import ( "context" "errors" + "sync" "testing" "time" @@ -157,3 +158,138 @@ func TestPrepareDataPluginsWithTimeout(t *testing.T) { }) } } + +type dagTestPlugin struct { + mockPrepareRequestDataPlugin + produces map[string]any + consumes map[string]any + execTime time.Time + mu sync.Mutex +} + +func (p *dagTestPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + p.mu.Lock() + defer p.mu.Unlock() + p.execTime = time.Now() + return p.mockPrepareRequestDataPlugin.PrepareRequestData(ctx, request, pods) +} + +func (p *dagTestPlugin) Produces() map[string]any 
{ + return p.produces +} + +func (p *dagTestPlugin) Consumes() map[string]any { + return p.consumes +} + +func TestExecutePluginsAsDAG(t *testing.T) { + pluginA := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "A", delay: 20 * time.Millisecond}, + produces: map[string]any{"keyA": nil}, + } + pluginB := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "B"}, + consumes: map[string]any{"keyA": nil}, + produces: map[string]any{"keyB": nil}, + } + pluginC := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "C"}, + consumes: map[string]any{"keyB": nil}, + } + pluginD := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "D"}, + consumes: map[string]any{"keyA": nil}, + } + pluginE := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "E"}, + } + pluginFail := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "Fail", returnErr: errors.New("plugin failed")}, + produces: map[string]any{"keyFail": nil}, + } + pluginDependsOnFail := &dagTestPlugin{ + mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "DependsOnFail"}, + consumes: map[string]any{"keyFail": nil}, + } + + testCases := []struct { + name string + plugins []PrepareDataPlugin + expectErr bool + checkFunc func(t *testing.T, plugins []PrepareDataPlugin) + }{ + { + name: "no plugins", + plugins: []PrepareDataPlugin{}, + }, + { + name: "simple linear dependency (A -> B -> C)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC}, + checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) { + pA := plugins[0].(*dagTestPlugin) + pB := plugins[1].(*dagTestPlugin) + pC := plugins[2].(*dagTestPlugin) + + assert.True(t, pA.executed, "Plugin A should have been executed") + assert.True(t, pB.executed, "Plugin B should have been executed") + assert.True(t, pC.executed, "Plugin C should have been executed") + + 
assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A") + assert.True(t, pC.execTime.After(pB.execTime), "Plugin C should execute after B") + }, + }, + { + name: "DAG with multiple dependencies (A -> B, A -> D) and one independent (E)", + plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE}, + checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) { + pA := plugins[0].(*dagTestPlugin) + pB := plugins[1].(*dagTestPlugin) + pD := plugins[2].(*dagTestPlugin) + pE := plugins[3].(*dagTestPlugin) + + assert.True(t, pA.executed, "Plugin A should have been executed") + assert.True(t, pB.executed, "Plugin B should have been executed") + assert.True(t, pD.executed, "Plugin D should have been executed") + assert.True(t, pE.executed, "Plugin E should have been executed") + + assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A") + assert.True(t, pD.execTime.After(pA.execTime), "Plugin D should execute after A") + }, + }, + { + name: "dependency fails", + plugins: []PrepareDataPlugin{pluginFail, pluginDependsOnFail}, + expectErr: true, + checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) { + pF := plugins[0].(*dagTestPlugin) + pDOF := plugins[1].(*dagTestPlugin) + + assert.True(t, pF.executed, "Failing plugin should have been executed") + assert.False(t, pDOF.executed, "Plugin depending on fail should not be executed") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset execution state for plugins + for _, p := range tc.plugins { + plugin := p.(*dagTestPlugin) + plugin.executed = false + plugin.execTime = time.Time{} + } + + err := executePluginsAsDAG(tc.plugins, context.Background(), &schedulingtypes.LLMRequest{}, nil) + + if tc.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if tc.checkFunc != nil { + tc.checkFunc(t, tc.plugins) + } + }) + } +} diff --git a/pkg/epp/requestcontrol/request_control_config.go 
b/pkg/epp/requestcontrol/request_control_config.go index 9701be999..1b00480bd 100644 --- a/pkg/epp/requestcontrol/request_control_config.go +++ b/pkg/epp/requestcontrol/request_control_config.go @@ -104,3 +104,15 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) { } } } + +// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order. +// If a cycle is detected, it returns an error. +func (c *Config) PrepareDataPluginGraph() error { + _, plugins, err := prepareDataGraph(c.prepareDataPlugins) + if err != nil { + return err + } + c.prepareDataPlugins = plugins + + return nil +}