Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf

// Add requestControl plugins
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
if r.requestControlConfig.PrepareDataPluginGraph() != nil {
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
}

// Handler deprecated configuration options
r.deprecatedConfigurationHelper(cfg, logger)
Expand Down
128 changes: 128 additions & 0 deletions pkg/epp/requestcontrol/dag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package requestcontrol

import (
"errors"
"slices"
)

// buildDAG derives the dependency graph between data preparation plugins from
// the data keys they declare.
//
// The returned map is keyed by plugin name (TypedName().String()). Each value
// lists the names of the plugins that produce at least one key the plugin
// consumes, i.e. the plugins it depends on. Every plugin gets an entry, even
// when it has no dependencies.
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
	graph := make(map[string][]string)
	for _, p := range plugins {
		graph[p.TypedName().String()] = []string{}
	}

	for producerIdx, producer := range plugins {
		for consumerIdx, consumer := range plugins {
			// A plugin is never compared against itself.
			if producerIdx == consumerIdx {
				continue
			}
			produced, consumed := producer.Produces(), consumer.Consumes()
			if produced == nil || consumed == nil {
				continue
			}
			for key := range produced {
				if _, found := consumed[key]; !found {
					continue
				}
				// One shared key is enough: record the edge once and move on
				// to the next consumer.
				consumerName := consumer.TypedName().String()
				graph[consumerName] = append(graph[consumerName], producer.TypedName().String())
				break
			}
		}
	}
	return graph
}

// prepareDataGraph computes the dependency graph of the given plugins and the
// plugins arranged in the order returned by topologicalSort.
//
// It returns the graph produced by buildDAG, the ordered plugins, and a nil
// error. When topologicalSort reports an error (a dependency cycle), both the
// graph and the ordered plugins are nil and that error is returned.
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, []PrepareDataPlugin, error) {
	// Index the plugins by name so the sorted names can be mapped back.
	byName := make(map[string]PrepareDataPlugin, len(plugins))
	for _, p := range plugins {
		byName[p.TypedName().String()] = p
	}

	dag := buildDAG(plugins)
	order, err := topologicalSort(dag)
	if err != nil {
		return nil, nil, err
	}

	ordered := make([]PrepareDataPlugin, 0, len(order))
	for _, name := range order {
		ordered = append(ordered, byName[name])
	}
	return dag, ordered, nil
}

// topologicalSort orders the nodes of graph using Kahn's algorithm.
//
// graph maps every node to the list of nodes it points to. The returned slice
// is the reverse of the Kahn order, so each node appears after all the nodes
// it points to. When the lists hold a node's dependencies, dependencies
// therefore come first. Nodes that only appear inside a list (and not as a
// key) are still part of the result.
//
// The result is deterministic for a given graph: nodes that are ready at the
// same time are taken in sorted name order rather than in map iteration order.
//
// It returns an error if the graph contains a cycle.
func topologicalSort(graph map[string][]string) ([]string, error) {
	// 1. Count incoming edges. Every key is registered up front so that nodes
	// nobody points to still show up with an in-degree of 0.
	inDegree := make(map[string]int, len(graph))
	for u, neighbors := range graph {
		if _, ok := inDegree[u]; !ok {
			inDegree[u] = 0
		}
		for _, v := range neighbors {
			inDegree[v]++
		}
	}

	// 2. Seed the queue with the nodes that have no incoming edges.
	queue := make([]string, 0, len(inDegree))
	for node, degree := range inDegree {
		if degree == 0 {
			queue = append(queue, node)
		}
	}
	// Map iteration order is randomized; sorting the seeds makes the final
	// order reproducible across runs.
	slices.Sort(queue)

	result := make([]string, 0, len(inDegree))

	// 3. Repeatedly take a ready node and release the nodes it points to.
	for len(queue) > 0 {
		u := queue[0]
		queue = queue[1:]
		result = append(result, u)

		// Indexing a missing key yields a nil slice, so no presence check is needed.
		for _, v := range graph[u] {
			inDegree[v]--
			if inDegree[v] == 0 {
				queue = append(queue, v)
			}
		}
	}

	// 4. Any node left unprocessed is part of a cycle.
	if len(result) != len(inDegree) {
		return nil, errors.New("cycle detected: graph is not a DAG")
	}

	// Kahn's order lists a node before the nodes it points to; flip it so the
	// pointed-to nodes come first.
	slices.Reverse(result)
	return result, nil
}
159 changes: 159 additions & 0 deletions pkg/epp/requestcontrol/dag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package requestcontrol

import (
"context"
"maps"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// mockPrepareRequestDataP is a configurable prepare-data plugin used by the
// tests in this file. The produced and consumed key sets are returned verbatim
// by Produces and Consumes.
type mockPrepareRequestDataP struct {
	name     string         // reported as the Name of the plugin's TypedName
	produces map[string]any // keys returned by Produces
	consumes map[string]any // keys returned by Consumes
}

// TypedName reports the configured name together with the fixed type "mock".
func (p *mockPrepareRequestDataP) TypedName() plugins.TypedName {
	typedName := plugins.TypedName{Type: "mock", Name: p.name}
	return typedName
}

// Produces returns the configured produced-key set (nil when unset).
func (p *mockPrepareRequestDataP) Produces() map[string]any {
	return p.produces
}

// Consumes returns the configured consumed-key set (nil when unset).
func (p *mockPrepareRequestDataP) Consumes() map[string]any {
	return p.consumes
}

// PrepareRequestData stores a fixed mock value on the first pod and always
// succeeds. It indexes pods[0] directly, so it requires at least one pod.
func (p *mockPrepareRequestDataP) PrepareRequestData(_ context.Context, _ *types.LLMRequest, pods []types.Pod) error {
	first := pods[0]
	first.Put(mockProducedDataKey, mockProducedDataType{value: 42})
	return nil
}

// TestPrepareDataGraph checks that prepareDataGraph builds the expected
// dependency graph, returns every plugin in an order that respects those
// dependencies, and reports an error for cyclic dependencies.
func TestPrepareDataGraph(t *testing.T) {
	pluginA := &mockPrepareRequestDataP{name: "A", produces: map[string]any{"keyA": nil}}
	pluginB := &mockPrepareRequestDataP{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}}
	pluginC := &mockPrepareRequestDataP{name: "C", consumes: map[string]any{"keyB": nil}}
	pluginD := &mockPrepareRequestDataP{name: "D", consumes: map[string]any{"keyA": nil}}
	pluginE := &mockPrepareRequestDataP{name: "E"} // No dependencies

	// Cycle plugins
	pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
	pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}

	testCases := []struct {
		name        string
		plugins     []PrepareDataPlugin
		expectedDAG map[string][]string
		expectError bool
	}{
		{
			name:        "No plugins",
			plugins:     []PrepareDataPlugin{},
			expectedDAG: map[string][]string{},
			expectError: false,
		},
		{
			name:    "Plugins with no dependencies",
			plugins: []PrepareDataPlugin{pluginA, pluginE},
			expectedDAG: map[string][]string{
				"A/mock": {},
				"E/mock": {},
			},
			expectError: false,
		},
		{
			name:    "Simple linear dependency (C -> B -> A)",
			plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
			expectedDAG: map[string][]string{
				"A/mock": {},
				"B/mock": {"A/mock"},
				"C/mock": {"B/mock"},
			},
			expectError: false,
		},
		{
			name:    "DAG with multiple dependencies (B -> A, D -> A, E independent)",
			plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
			expectedDAG: map[string][]string{
				"A/mock": {},
				"B/mock": {"A/mock"},
				"D/mock": {"A/mock"},
				"E/mock": {},
			},
			expectError: false,
		},
		{
			name:        "Graph with a cycle (X -> Y, Y -> X)",
			plugins:     []PrepareDataPlugin{pluginX, pluginY},
			expectedDAG: nil,
			expectError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			dag, orderedPlugins, err := prepareDataGraph(tc.plugins)

			if tc.expectError {
				// Bail out when no error was returned: err.Error() below
				// would otherwise panic on a nil error.
				if !assert.Error(t, err) {
					return
				}
				assert.Nil(t, dag)
				assert.Contains(t, err.Error(), "cycle detected")
				return
			}

			// The remaining checks are meaningless without a graph.
			if !assert.NoError(t, err) {
				return
			}

			// Copy both graphs into fresh maps so that a nil map and an empty
			// map compare as equal.
			normalizedDAG := make(map[string][]string, len(dag))
			maps.Copy(normalizedDAG, dag)
			normalizedExpectedDAG := make(map[string][]string, len(tc.expectedDAG))
			maps.Copy(normalizedExpectedDAG, tc.expectedDAG)

			if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
				t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
			}

			// Every input plugin must be present in the ordered result.
			assert.Len(t, orderedPlugins, len(tc.plugins))
			assertTopologicalOrder(t, dag, orderedPlugins)
		})
	}
}

// assertTopologicalOrder verifies that ordered respects dag: for every node in
// dag, each of its dependencies must appear at a smaller index than the node
// itself. Plugins are identified by TypedName().String().
//
// A node or dependency that is absent from ordered is reported as a failure
// instead of being treated as index 0, which could otherwise let a missing
// dependency pass the ordering check.
func assertTopologicalOrder(t *testing.T, dag map[string][]string, ordered []PrepareDataPlugin) {
	t.Helper()
	positions := make(map[string]int, len(ordered))
	for i, p := range ordered {
		positions[p.TypedName().String()] = i
	}

	for node, dependencies := range dag {
		nodePos, ok := positions[node]
		if !assert.True(t, ok, "Plugin %s is missing from the ordered plugins", node) {
			continue
		}
		for _, dep := range dependencies {
			depPos, ok := positions[dep]
			if !assert.True(t, ok, "Dependency %s of %s is missing from the ordered plugins", dep, node) {
				continue
			}
			assert.Less(t, depPos, nodePos, "Dependency %s should come before %s", dep, node)
		}
	}
}
6 changes: 1 addition & 5 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}
}

// runPrepareDataPlugins runs the director's configured PrepareData plugins for
// the given request and candidate pods. It delegates to
// prepareDataPluginsWithTimeout, passing prepareDataTimeout as the timeout,
// and returns whatever error that call reports.
func (d *Director) runPrepareDataPlugins(ctx context.Context,
	request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	return prepareDataPluginsWithTimeout(
		prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
}

func (d *Director) runAdmissionPlugins(ctx context.Context,
Expand Down
22 changes: 13 additions & 9 deletions pkg/epp/requestcontrol/plugin_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,24 @@ import (
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// executePluginsAsDAG invokes PrepareRequestData on each plugin, one after
// another, in the order of the given slice. It stops at the first plugin that
// fails and returns an error naming that plugin; it returns nil when every
// plugin succeeds.
// NOTE(review): dependency ordering is not computed here; presumably callers
// pass plugins that are already topologically sorted — confirm.
func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	for i := range plugins {
		current := plugins[i]
		err := current.PrepareRequestData(ctx, request, pods)
		if err == nil {
			continue
		}
		return errors.New("prepare data plugin " + current.TypedName().String() + " failed: " + err.Error())
	}
	return nil
}

// prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout.
func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin,
ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
errCh := make(chan error, 1)
// Execute plugins sequentially in a separate goroutine
go func() {
for _, plugin := range plugins {
err := plugin.PrepareRequestData(ctx, request, pods)
if err != nil {
errCh <- errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error())
return
}
}
errCh <- nil
errCh <- executePluginsAsDAG(plugins, ctx, request, pods)
}()

select {
Expand Down
Loading