Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf

// Add requestControl plugins
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
// Check prepare data plugins for cycles.
if r.requestControlConfig.ValidatePrepareDataPlugins() != nil {
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
}

// Handler deprecated configuration options
r.deprecatedConfigurationHelper(cfg, logger)
Expand Down
99 changes: 99 additions & 0 deletions pkg/epp/requestcontrol/dag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package requestcontrol

import "errors"

// buildDAG builds a dependency graph among data preparation plugins based on their
// produced and consumed data keys.
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
dag := make(map[string][]string)
for _, plugin := range plugins {
dag[plugin.TypedName().String()] = []string{}
}
// Create dependency graph as a DAG.
for i := range plugins {
for j := range plugins {
if i == j {
continue
}
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
// For all the keys produced by plugin i, check if plugin j consumes any of them.
// If yes, then j depends on i.
for producedKey := range plugins[i].Produces() {
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
if _, ok := plugins[j].Consumes()[producedKey]; ok {
iPluginName := plugins[i].TypedName().String()
jPluginName := plugins[j].TypedName().String()
dag[jPluginName] = append(dag[jPluginName], iPluginName)
break
}
}
}
}
}
return dag
}

// prepareDataGraph builds a DAG of data preparation plugins and checks for cycles.
// If there is a cycle, it returns an error.
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) {
dag := buildDAG(plugins)

// Check for cycles in the DAG.
if cycleExistsInDAG(dag) {
return nil, errors.New("cycle detected in data preparation plugin dependencies")
}

return dag, nil
}

// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list.
func cycleExistsInDAG(dag map[string][]string) bool {
visited := make(map[string]bool)
recStack := make(map[string]bool)

var dfs func(string) bool
dfs = func(node string) bool {
if recStack[node] {
return true // Cycle detected
}
if visited[node] {
return false
}
visited[node] = true
recStack[node] = true

for _, neighbor := range dag[node] {
if dfs(neighbor) {
return true
}
}
recStack[node] = false
return false
}

for pluginName := range dag {
if !visited[pluginName] {
if dfs(pluginName) {
return true
}
}
}
return false
}
146 changes: 146 additions & 0 deletions pkg/epp/requestcontrol/dag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package requestcontrol

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

type mockPrepareRequestDataP struct {
name string
produces map[string]any
consumes map[string]any
}

func (m *mockPrepareRequestDataP) TypedName() plugins.TypedName {
return plugins.TypedName{Name: m.name, Type: "mock"}
}

func (m *mockPrepareRequestDataP) Produces() map[string]any {
return m.produces
}

func (m *mockPrepareRequestDataP) Consumes() map[string]any {
return m.consumes
}

func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
return nil
}

func TestPrepareDataGraph(t *testing.T) {
pluginA := &mockPrepareRequestDataP{name: "A", produces: map[string]any{"keyA": nil}}
pluginB := &mockPrepareRequestDataP{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}}
pluginC := &mockPrepareRequestDataP{name: "C", consumes: map[string]any{"keyB": nil}}
pluginD := &mockPrepareRequestDataP{name: "D", consumes: map[string]any{"keyA": nil}}
pluginE := &mockPrepareRequestDataP{name: "E"} // No dependencies

// Cycle plugins
pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}

testCases := []struct {
name string
plugins []PrepareDataPlugin
expectedDAG map[string][]string
expectError bool
}{
{
name: "No plugins",
plugins: []PrepareDataPlugin{},
expectedDAG: map[string][]string{},
expectError: false,
},
{
name: "Plugins with no dependencies",
plugins: []PrepareDataPlugin{pluginA, pluginE},
expectedDAG: map[string][]string{
"A/mock": {},
"E/mock": {},
},
expectError: false,
},
{
name: "Simple linear dependency (A -> B -> C)",
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
expectedDAG: map[string][]string{
"A/mock": {},
"B/mock": {"A/mock"},
"C/mock": {"B/mock"},
},
expectError: false,
},
{
name: "DAG with multiple dependencies (A -> B, A -> D)",
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
expectedDAG: map[string][]string{
"A/mock": {},
"B/mock": {"A/mock"},
"D/mock": {"A/mock"},
"E/mock": {},
},
expectError: false,
},
{
name: "Graph with a cycle (X -> Y, Y -> X)",
plugins: []PrepareDataPlugin{pluginX, pluginY},
expectedDAG: nil,
expectError: true,
},
{
name: "Complex graph with a cycle",
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY},
expectedDAG: nil,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dag, err := prepareDataGraph(tc.plugins)

if tc.expectError {
assert.Error(t, err)
assert.Nil(t, dag)
assert.Contains(t, err.Error(), "cycle detected")
} else {
assert.NoError(t, err)

// Normalize the slices in the maps for consistent comparison
normalizedDAG := make(map[string][]string)
for k, v := range dag {
normalizedDAG[k] = v
}
normalizedExpectedDAG := make(map[string][]string)
for k, v := range tc.expectedDAG {
normalizedExpectedDAG[k] = v
}

if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
}
}
})
}
}
6 changes: 1 addition & 5 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}
}

// TODO: Execute plugins in parallel once DAG execution is supported.
// runPrepareDataPlugins executes PrepareDataPlugins sequentially.
func (d *Director) runPrepareDataPlugins(ctx context.Context,
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
return prepareDataPluginsWithTimeout(
prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)

return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
}

func (d *Director) runAdmissionPlugins(ctx context.Context,
Expand Down
77 changes: 68 additions & 9 deletions pkg/epp/requestcontrol/plugin_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,84 @@ package requestcontrol
import (
"context"
"errors"
"fmt"
"time"

schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously.
// So, a plugin is executed only after all its dependencies have been executed.
// If there is a cycle or any plugin fails with error, it returns an error.
func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
// Build the DAG
// The error validation happens on startup when loading the config. So, here there should not be any error.
dag, err := prepareDataGraph(plugins)
if err != nil {
return err
}
// Create a readonly map of plugin name to plugin instance.
nameToNode := map[string]PrepareDataPlugin{}
for _, plugin := range plugins {
nameToNode[plugin.TypedName().String()] = plugin
}
// Execute the DAG

// Channels to signal plugin execution completion.
pluginExecuted := make(map[string]chan error)
// The capacity of the channel is equal to the number of dependents + 1 (for itself).
capacityMap := make(map[string]int)
for pluginName := range dag {
capacityMap[pluginName]++
for _, dep := range dag[pluginName] {
capacityMap[dep]++
}
}
for pluginName, capacity := range capacityMap {
pluginExecuted[pluginName] = make(chan error, capacity)
}
for pluginName, dependents := range dag {
// Execute plugins based on dependencies.
// Wait for the dependencies to complete before executing a plugin.
go func(pN string, depds []string) {
for _, dep := range depds {
// Wait for the dependency plugin to signal completion.
if err, closed := <-pluginExecuted[dep]; closed {
if err != nil {
for range cap(pluginExecuted[pN]) {
// Notify all dependents about the failure.
pluginExecuted[pN] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err)
}
// Do not execute this plugin as one of its dependencies failed.
return
}
}
}
res := nameToNode[pN].PrepareRequestData(ctx, request, pods)
for range cap(pluginExecuted[pN]) {
// Notify all dependents about the completion.
pluginExecuted[pN] <- res
}
}(pluginName, dependents)
}

// Check for errors in plugin execution.
// This will also ensure that all plugins have completed execution before returning.
for pluginName := range dag {
err := <-pluginExecuted[pluginName]
if err != nil {
return fmt.Errorf("prepare data plugin %s failed: %v", pluginName, err)
}
}
return nil
}

// prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout.
func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin,
ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
errCh := make(chan error, 1)
// Execute plugins sequentially in a separate goroutine
go func() {
for _, plugin := range plugins {
err := plugin.PrepareRequestData(ctx, request, pods)
if err != nil {
errCh <- errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error())
return
}
}
errCh <- nil
errCh <- executePluginsAsDAG(plugins, ctx, request, pods)
}()

select {
Expand Down
Loading