Skip to content

Commit 5f64c1d

Browse files
authored
refactor: Allow export prefix SchedulingContextState for use across plugins (#1063)
* refactor: Allow export prefix SchedulingContextState for use across plugins

  Signed-off-by: Kfir Toledo <[email protected]>

* refactor: replace getPrefixState with generic ReadCycleStateAs

  Signed-off-by: Kfir Toledo <[email protected]>

---------

Signed-off-by: Kfir Toledo <[email protected]>
1 parent b4d1e67 commit 5f64c1d

File tree

3 files changed

+32
-29
lines changed

3 files changed

+32
-29
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,25 @@ func (s ServerID) String() string {
9090
}
9191

9292
// compile-time type validation
93-
var _ types.StateData = &schedulingContextState{}
93+
var _ types.StateData = &SchedulingContextState{}
9494

95-
// This is the state of this plugin to be used during a scheduling cycle.
96-
type schedulingContextState struct {
95+
// SchedulingContextState is the state of this plugin to be used during a scheduling cycle.
96+
type SchedulingContextState struct {
9797
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
9898
PrefixHashes []BlockHash
9999
// PrefixCacheServers maps each server to the length of its longest prefix cache match.
100100
PrefixCacheServers map[ServerID]int
101101
}
102102

103-
func (s *schedulingContextState) Clone() types.StateData {
103+
func (s *SchedulingContextState) Clone() types.StateData {
104104
prefixHashes := make([]BlockHash, len(s.PrefixHashes))
105105
copy(prefixHashes, s.PrefixHashes)
106106
prefixCacheServers := make(map[ServerID]int, len(s.PrefixCacheServers))
107107
for key, value := range s.PrefixCacheServers {
108108
prefixCacheServers[key] = value
109109
}
110110

111-
return &schedulingContextState{
111+
return &SchedulingContextState{
112112
PrefixHashes: prefixHashes,
113113
PrefixCacheServers: prefixCacheServers,
114114
}
@@ -171,7 +171,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
171171
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
172172
// Pre-score step: hash the prompt and find the longest prefix match.
173173
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
174-
state := &schedulingContextState{
174+
state := &SchedulingContextState{
175175
PrefixHashes: hashes,
176176
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
177177
}
@@ -199,7 +199,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
199199
// PostCycle records the result of the scheduling selection in the plugin cache.
200200
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
201201
targetPod := res.TargetPod.GetPod()
202-
state, err := m.getPrefixState(cycleState)
202+
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
203203
if err != nil {
204204
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
205205
return
@@ -235,22 +235,6 @@ func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
235235
return res
236236
}
237237

238-
// getPrefixState returns the cycle state as a schedulingContextState.
239-
func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContextState, error) {
240-
prefixStateKey := types.StateKey(m.Type())
241-
state, err := cycleState.Read(prefixStateKey)
242-
if err != nil {
243-
return nil, fmt.Errorf("failed reading %q from CycleState: %w", prefixStateKey, err)
244-
}
245-
246-
prefixSchedulingState, ok := state.(*schedulingContextState)
247-
if !ok {
248-
return nil, fmt.Errorf("invalid Prefix state, got type %T", state)
249-
}
250-
251-
return prefixSchedulingState, nil
252-
}
253-
254238
// hashPrompt divides the prompt into blocks and calculates a prefix hash for each block.
255239
// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
256240
// For block i, hash(i) = hash(block i content, hash(i-1)).

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func TestPrefixPlugin(t *testing.T) {
5050
}
5151
cycleState1 := types.NewCycleState()
5252
scores := plugin.Score(context.Background(), cycleState1, req1, pods)
53-
state, err := plugin.getPrefixState(cycleState1)
53+
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState1, PrefixCachePluginType)
5454
assert.NoError(t, err)
5555
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
5656
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -71,7 +71,7 @@ func TestPrefixPlugin(t *testing.T) {
7171
}
7272
cycleState2 := types.NewCycleState()
7373
scores = plugin.Score(context.Background(), cycleState2, req2, pods)
74-
state, err = plugin.getPrefixState(cycleState2)
74+
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState2, PrefixCachePluginType)
7575
assert.NoError(t, err)
7676
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
7777
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -91,7 +91,7 @@ func TestPrefixPlugin(t *testing.T) {
9191
}
9292
cycleState3 := types.NewCycleState()
9393
scores = plugin.Score(context.Background(), cycleState3, req3, pods)
94-
state, err = plugin.getPrefixState(cycleState3)
94+
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState3, PrefixCachePluginType)
9595
assert.NoError(t, err)
9696
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
9797
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -110,7 +110,7 @@ func TestPrefixPlugin(t *testing.T) {
110110
}
111111
cycleState4 := types.NewCycleState()
112112
scores = plugin.Score(context.Background(), cycleState4, req4, pods)
113-
state, err = plugin.getPrefixState(cycleState4)
113+
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState4, PrefixCachePluginType)
114114
assert.NoError(t, err)
115115
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
116116
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -129,7 +129,7 @@ func TestPrefixPlugin(t *testing.T) {
129129
}
130130
cycleState5 := types.NewCycleState()
131131
scores = plugin.Score(context.Background(), cycleState5, req5, pods)
132-
state, err = plugin.getPrefixState(cycleState5)
132+
state, err = types.ReadCycleStateKey[*SchedulingContextState](cycleState5, PrefixCachePluginType)
133133
assert.NoError(t, err)
134134
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
135135
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
@@ -183,7 +183,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
183183
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})
184184

185185
// Second cycle: validate internal state
186-
state, err := plugin.getPrefixState(cycleState)
186+
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
187187
assert.NoError(b, err)
188188
expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model.
189189
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")

pkg/epp/scheduling/types/cycle_state.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package types
1818

1919
import (
2020
"errors"
21+
"fmt"
2122
"sync"
2223
)
2324

@@ -90,3 +91,21 @@ func (c *CycleState) Write(key StateKey, val StateData) {
9091
func (c *CycleState) Delete(key StateKey) {
9192
c.storage.Delete(key)
9293
}
94+
95+
// ReadCycleStateKey retrieves data with the given key from CycleState and asserts it to type T.
96+
// It returns the zero value of T and an error if reading the key fails or the stored value is not of type T.
97+
func ReadCycleStateKey[T StateData](c *CycleState, key StateKey) (T, error) {
98+
var zero T
99+
100+
raw, err := c.Read(key)
101+
if err != nil {
102+
return zero, err
103+
}
104+
105+
val, ok := raw.(T)
106+
if !ok {
107+
return zero, fmt.Errorf("unexpected type for key %q: got %T", key, raw)
108+
}
109+
110+
return val, nil
111+
}

0 commit comments

Comments
 (0)