Skip to content

Commit c80d0d3

Browse files
committed
Use buffered channel based approach to synchronize go routines. Also add unit tests
1 parent 7524eff commit c80d0d3

File tree

2 files changed

+171
-22
lines changed

2 files changed

+171
-22
lines changed

pkg/epp/requestcontrol/plugin_executor.go

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,48 +27,62 @@ import (
2727

2828
// executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously.
2929
// So, a plugin is executed only after all its dependencies have been executed.
30-
// If there is a cycle or other error in the DAG, it returns an error.
30+
// If there is a cycle or any plugin fails with error, it returns an error.
3131
func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
3232
// Build the DAG
3333
// The error validation happens on startup when loading the config. So, here there should not be any error.
3434
dag, err := prepareDataGraph(plugins)
3535
if err != nil {
3636
return err
3737
}
38-
// Execute the DAG
39-
40-
// Channels to signal plugin execution completion.
41-
pluginExecuted := make(map[string]chan error)
38+
// Create a readonly map of plugin name to plugin instance.
4239
nameToNode := map[string]PrepareDataPlugin{}
4340
for _, plugin := range plugins {
44-
pluginExecuted[plugin.TypedName().String()] = make(chan error)
4541
nameToNode[plugin.TypedName().String()] = plugin
4642
}
43+
// Execute the DAG
4744

45+
// Channels to signal plugin execution completion.
46+
pluginExecuted := make(map[string]chan error)
47+
// The capacity of the channel is equal to the number of dependents + 1 (for itself).
48+
capacityMap := make(map[string]int)
49+
for pluginName := range dag {
50+
capacityMap[pluginName]++
51+
for _, dep := range dag[pluginName] {
52+
capacityMap[dep]++
53+
}
54+
}
55+
for pluginName, capacity := range capacityMap {
56+
pluginExecuted[pluginName] = make(chan error, capacity)
57+
}
4858
for pluginName, dependents := range dag {
4959
// Execute plugins based on dependencies.
5060
// Wait for the dependencies to complete before executing a plugin.
51-
go func() {
52-
for _, dep := range dependents {
53-
err, open := <-pluginExecuted[dep]
54-
if !open {
55-
continue
56-
}
57-
if err != nil {
58-
// If a dependency failed, propagate the error and do not execute this plugin.
59-
pluginExecuted[pluginName] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err)
61+
go func(pN string, depds []string) {
62+
// Wait for dependencies to complete.
63+
for _, dep := range depds {
64+
// Wait for the dependency plugin to signal completion.
65+
if err, closed := <-pluginExecuted[dep]; closed {
66+
if err != nil {
67+
pluginExecuted[pN] <- fmt.Errorf("dependency plugin %s failed: %w", dep, err)
68+
return
69+
}
6070
}
6171
}
62-
// Signal that the plugin has been executed.
63-
defer close(pluginExecuted[pluginName])
64-
65-
pluginExecuted[pluginName] <- nameToNode[pluginName].PrepareRequestData(ctx, request, pods)
66-
}()
72+
res := nameToNode[pN].PrepareRequestData(ctx, request, pods)
73+
for range cap(pluginExecuted[pN]) {
74+
// Notify all dependents about the completion.
75+
pluginExecuted[pN] <- res
76+
}
77+
}(pluginName, dependents)
6778
}
79+
80+
// Check for errors in plugin execution.
81+
// This will also ensure that all plugins have completed execution before returning.
6882
for pluginName := range dag {
6983
err := <-pluginExecuted[pluginName]
7084
if err != nil {
71-
return errors.New("prepare data plugin " + pluginName + " failed: " + err.Error())
85+
return fmt.Errorf("prepare data plugin %s failed: %v", pluginName, err)
7286
}
7387
}
7488
return nil
@@ -78,7 +92,6 @@ func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, reque
7892
func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin,
7993
ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
8094
errCh := make(chan error, 1)
81-
// Execute plugins sequentially in a separate goroutine
8295
go func() {
8396
errCh <- executePluginsAsDAG(plugins, ctx, request, pods)
8497
}()

pkg/epp/requestcontrol/plugin_executor_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package requestcontrol
1919
import (
2020
"context"
2121
"errors"
22+
"sync"
2223
"testing"
2324
"time"
2425

@@ -157,3 +158,138 @@ func TestPrepareDataPluginsWithTimeout(t *testing.T) {
157158
})
158159
}
159160
}
161+
162+
type dagTestPlugin struct {
163+
mockPrepareRequestDataPlugin
164+
produces map[string]any
165+
consumes map[string]any
166+
execTime time.Time
167+
mu sync.Mutex
168+
}
169+
170+
func (p *dagTestPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
171+
p.mu.Lock()
172+
defer p.mu.Unlock()
173+
p.execTime = time.Now()
174+
return p.mockPrepareRequestDataPlugin.PrepareRequestData(ctx, request, pods)
175+
}
176+
177+
func (p *dagTestPlugin) Produces() map[string]any {
178+
return p.produces
179+
}
180+
181+
func (p *dagTestPlugin) Consumes() map[string]any {
182+
return p.consumes
183+
}
184+
185+
func TestExecutePluginsAsDAG(t *testing.T) {
186+
pluginA := &dagTestPlugin{
187+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "A", delay: 20 * time.Millisecond},
188+
produces: map[string]any{"keyA": nil},
189+
}
190+
pluginB := &dagTestPlugin{
191+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "B"},
192+
consumes: map[string]any{"keyA": nil},
193+
produces: map[string]any{"keyB": nil},
194+
}
195+
pluginC := &dagTestPlugin{
196+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "C"},
197+
consumes: map[string]any{"keyB": nil},
198+
}
199+
pluginD := &dagTestPlugin{
200+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "D"},
201+
consumes: map[string]any{"keyA": nil},
202+
}
203+
pluginE := &dagTestPlugin{
204+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "E"},
205+
}
206+
pluginFail := &dagTestPlugin{
207+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "Fail", returnErr: errors.New("plugin failed")},
208+
produces: map[string]any{"keyFail": nil},
209+
}
210+
pluginDependsOnFail := &dagTestPlugin{
211+
mockPrepareRequestDataPlugin: mockPrepareRequestDataPlugin{name: "DependsOnFail"},
212+
consumes: map[string]any{"keyFail": nil},
213+
}
214+
215+
testCases := []struct {
216+
name string
217+
plugins []PrepareDataPlugin
218+
expectErr bool
219+
checkFunc func(t *testing.T, plugins []PrepareDataPlugin)
220+
}{
221+
{
222+
name: "no plugins",
223+
plugins: []PrepareDataPlugin{},
224+
},
225+
{
226+
name: "simple linear dependency (A -> B -> C)",
227+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
228+
checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) {
229+
pA := plugins[0].(*dagTestPlugin)
230+
pB := plugins[1].(*dagTestPlugin)
231+
pC := plugins[2].(*dagTestPlugin)
232+
233+
assert.True(t, pA.executed, "Plugin A should have been executed")
234+
assert.True(t, pB.executed, "Plugin B should have been executed")
235+
assert.True(t, pC.executed, "Plugin C should have been executed")
236+
237+
assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A")
238+
assert.True(t, pC.execTime.After(pB.execTime), "Plugin C should execute after B")
239+
},
240+
},
241+
{
242+
name: "DAG with multiple dependencies (A -> B, A -> D) and one independent (E)",
243+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
244+
checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) {
245+
pA := plugins[0].(*dagTestPlugin)
246+
pB := plugins[1].(*dagTestPlugin)
247+
pD := plugins[2].(*dagTestPlugin)
248+
pE := plugins[3].(*dagTestPlugin)
249+
250+
assert.True(t, pA.executed, "Plugin A should have been executed")
251+
assert.True(t, pB.executed, "Plugin B should have been executed")
252+
assert.True(t, pD.executed, "Plugin D should have been executed")
253+
assert.True(t, pE.executed, "Plugin E should have been executed")
254+
255+
assert.True(t, pB.execTime.After(pA.execTime), "Plugin B should execute after A")
256+
assert.True(t, pD.execTime.After(pA.execTime), "Plugin D should execute after A")
257+
},
258+
},
259+
{
260+
name: "dependency fails",
261+
plugins: []PrepareDataPlugin{pluginFail, pluginDependsOnFail},
262+
expectErr: true,
263+
checkFunc: func(t *testing.T, plugins []PrepareDataPlugin) {
264+
pF := plugins[0].(*dagTestPlugin)
265+
pDOF := plugins[1].(*dagTestPlugin)
266+
267+
assert.True(t, pF.executed, "Failing plugin should have been executed")
268+
assert.False(t, pDOF.executed, "Plugin depending on fail should not be executed")
269+
},
270+
},
271+
}
272+
273+
for _, tc := range testCases {
274+
t.Run(tc.name, func(t *testing.T) {
275+
// Reset execution state for plugins
276+
for _, p := range tc.plugins {
277+
plugin := p.(*dagTestPlugin)
278+
plugin.executed = false
279+
plugin.execTime = time.Time{}
280+
}
281+
282+
err := executePluginsAsDAG(tc.plugins, context.Background(), &schedulingtypes.LLMRequest{}, nil)
283+
284+
if tc.expectErr {
285+
assert.Error(t, err)
286+
} else {
287+
assert.NoError(t, err)
288+
}
289+
290+
if tc.checkFunc != nil {
291+
tc.checkFunc(t, tc.plugins)
292+
}
293+
})
294+
}
295+
}

0 commit comments

Comments
 (0)