Skip to content

Commit 0334dfa

Browse files
authored
Refactor "Update batch AI to use new Selector" (#3428)
1 parent 59ea865 commit 0334dfa

File tree

3 files changed

+53
-42
lines changed

3 files changed

+53
-42
lines changed

server/ai_session.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,12 @@ func NewAISessionSelector(ctx context.Context, cap core.Capability, modelID stri
191191
if cap == core.Capability_LiveVideoToVideo {
192192
// For Realtime Video AI, we don't use any features of MinLSSelector (preferring known sessions, etc.),
193193
// We always select a fresh session which has the lowest initial latency
194-
useInitialLatencyToSort := true
195-
warmSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps, useInitialLatencyToSort)
196-
coldSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps, useInitialLatencyToSort)
194+
warmSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps)
195+
coldSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps)
197196
} else {
198-
//sort sessions based on current latency score
199-
useInitialLatencyToSort := false
200-
warmSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps, useInitialLatencyToSort)
201-
coldSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps, useInitialLatencyToSort)
197+
// sort sessions based on current latency score
198+
warmSel = NewSelectorOrderByLatencyScore(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps)
199+
coldSel = NewSelectorOrderByLatencyScore(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps)
202200
}
203201

204202
warmPool := NewAISessionPool(warmSel, suspender, penalty)

server/selection.go

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -94,49 +94,60 @@ func (r *storeStakeReader) Stakes(addrs []ethcommon.Address) (map[ethcommon.Addr
9494
type Selector struct {
9595
sessions []*BroadcastSession
9696

97-
stakeRdr stakeReader
98-
selectionAlgorithm common.SelectionAlgorithm
99-
perfScore *common.PerfScore
100-
capabilities common.CapabilityComparator
101-
useInitialLatencyToSort bool
97+
stakeRdr stakeReader
98+
selectionAlgorithm common.SelectionAlgorithm
99+
perfScore *common.PerfScore
100+
capabilities common.CapabilityComparator
101+
sortCompFunc func(sess1, sess2 *BroadcastSession) bool
102102
}
103103

104-
func NewSelector(stakeRdr stakeReader, selectionAlgorithm common.SelectionAlgorithm, perfScore *common.PerfScore, capabilities common.CapabilityComparator, useInitialLatencyToSort bool) *Selector {
104+
func NewSelector(stakeRdr stakeReader, selectionAlgorithm common.SelectionAlgorithm, perfScore *common.PerfScore, capabilities common.CapabilityComparator) *Selector {
105+
// By default, sort by initial latency
106+
sortCompFunc := func(sess1, sess2 *BroadcastSession) bool {
107+
return sess1.InitialLatency < sess2.InitialLatency
108+
}
105109
return &Selector{
106-
stakeRdr: stakeRdr,
107-
selectionAlgorithm: selectionAlgorithm,
108-
perfScore: perfScore,
109-
capabilities: capabilities,
110-
useInitialLatencyToSort: useInitialLatencyToSort,
110+
stakeRdr: stakeRdr,
111+
selectionAlgorithm: selectionAlgorithm,
112+
perfScore: perfScore,
113+
capabilities: capabilities,
114+
sortCompFunc: sortCompFunc,
115+
}
116+
}
117+
118+
func NewSelectorOrderByLatencyScore(stakeRdr stakeReader, selectionAlgorithm common.SelectionAlgorithm, perfScore *common.PerfScore, capabilities common.CapabilityComparator) *Selector {
119+
sortCompFunc := func(sess1, sess2 *BroadcastSession) bool {
120+
return sess1.LatencyScore < sess2.LatencyScore
121+
}
122+
return &Selector{
123+
stakeRdr: stakeRdr,
124+
selectionAlgorithm: selectionAlgorithm,
125+
perfScore: perfScore,
126+
capabilities: capabilities,
127+
sortCompFunc: sortCompFunc,
111128
}
112129
}
113130

114131
func (s *Selector) Add(sessions []*BroadcastSession) {
115132
s.sessions = append(s.sessions, sessions...)
116-
s.sortByLatency()
133+
s.sort()
117134
}
118135

119136
func (s *Selector) Complete(sess *BroadcastSession) {
120137
s.sessions = append(s.sessions, sess)
121-
s.sortByLatency()
122-
}
123-
124-
func (s *Selector) sortByLatency() {
125-
if s.useInitialLatencyToSort {
126-
sort.Slice(s.sessions, func(i, j int) bool {
127-
return s.sessions[i].InitialLatency < s.sessions[j].InitialLatency
128-
})
129-
} else {
130-
sort.Slice(s.sessions, func(i, j int) bool {
131-
return s.sessions[i].LatencyScore < s.sessions[j].LatencyScore
132-
})
133-
}
138+
s.sort()
139+
}
140+
141+
func (s *Selector) sort() {
142+
sort.Slice(s.sessions, func(i, j int) bool {
143+
return s.sortCompFunc(s.sessions[i], s.sessions[j])
144+
})
134145
}
135146

136147
func (s *Selector) Select(ctx context.Context) *BroadcastSession {
137148
availableOrchestrators := toOrchestrators(s.sessions)
138149
sess := s.selectUnknownSession(ctx)
139-
s.sortByLatency()
150+
s.sort()
140151
clog.V(common.DEBUG).Infof(ctx, "Selected orchestrator %s from available list: %v", toOrchestrator(sess), availableOrchestrators)
141152
return sess
142153
}

server/selection_test.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func TestSelector_Select(t *testing.T) {
162162
assert := assert.New(t)
163163

164164
// given
165-
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil, true)
165+
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil)
166166
sessions := []*BroadcastSession{
167167
{PMSessionID: "session-1", InitialLatency: 400 * time.Millisecond},
168168
{PMSessionID: "session-2", InitialLatency: 200 * time.Millisecond},
@@ -185,7 +185,7 @@ func TestSelector_CompleteAndSelect(t *testing.T) {
185185
assert := assert.New(t)
186186

187187
// given
188-
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil, true)
188+
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil)
189189
sessions := []*BroadcastSession{
190190
{PMSessionID: "session-1", InitialLatency: 400 * time.Millisecond},
191191
{PMSessionID: "session-2", InitialLatency: 200 * time.Millisecond},
@@ -213,7 +213,7 @@ func TestSelector_Size(t *testing.T) {
213213
assert := assert.New(t)
214214

215215
// given
216-
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil, true)
216+
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil)
217217
sessions := []*BroadcastSession{
218218
{PMSessionID: "session-1", InitialLatency: 400 * time.Millisecond},
219219
{PMSessionID: "session-2", InitialLatency: 200 * time.Millisecond},
@@ -238,11 +238,10 @@ func TestSelector_Size(t *testing.T) {
238238
assert.Nil(sel.Select(context.Background()))
239239
}
240240

241-
func TestSelector_SortByLatency(t *testing.T) {
241+
func TestSelector_SortByInitialLatency(t *testing.T) {
242242
assert := assert.New(t)
243243

244-
// sort by initial latency
245-
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil, true)
244+
sel := NewSelector(nil, stubSelectionAlgorithm{}, nil, nil)
246245
sessions := []*BroadcastSession{
247246
{PMSessionID: "session-1", InitialLatency: 400 * time.Millisecond},
248247
{PMSessionID: "session-2", InitialLatency: 200 * time.Millisecond},
@@ -253,10 +252,13 @@ func TestSelector_SortByLatency(t *testing.T) {
253252
assert.Equal("session-2", sel.sessions[0].PMSessionID)
254253
assert.Equal("session-1", sel.sessions[1].PMSessionID)
255254
assert.Equal("session-3", sel.sessions[2].PMSessionID)
255+
}
256+
257+
func TestSelector_SortByLatencyScore(t *testing.T) {
258+
assert := assert.New(t)
256259

257-
// sort by initial latency
258-
sel = NewSelector(nil, stubSelectionAlgorithm{}, nil, nil, false)
259-
sessions = []*BroadcastSession{
260+
sel := NewSelectorOrderByLatencyScore(nil, stubSelectionAlgorithm{}, nil, nil)
261+
sessions := []*BroadcastSession{
260262
{PMSessionID: "session-1", InitialLatency: 400 * time.Millisecond, LatencyScore: 0.001},
261263
{PMSessionID: "session-2", InitialLatency: 200 * time.Millisecond, LatencyScore: 0.01},
262264
{PMSessionID: "session-3", InitialLatency: 600 * time.Millisecond, LatencyScore: 0.08},

0 commit comments

Comments
 (0)