@@ -9,9 +9,11 @@ import (
99 "sync"
1010 "time"
1111
12+ ethcommon "github.com/ethereum/go-ethereum/common"
1213 "github.com/livepeer/go-livepeer/clog"
1314 "github.com/livepeer/go-livepeer/common"
1415 "github.com/livepeer/go-livepeer/core"
16+ "github.com/livepeer/go-livepeer/pm"
1517 "github.com/livepeer/go-tools/drivers"
1618 "github.com/livepeer/lpms/stream"
1719)
@@ -273,6 +275,8 @@ func (sel *AISessionSelector) Remove(sess *AISession) {
273275}
274276
275277func (sel * AISessionSelector ) Refresh (ctx context.Context ) error {
278+ oldBalances , oldSenderSessions := sel .getBalances ()
279+
276280 sessions , err := sel .getSessions (ctx )
277281 if err != nil {
278282 return err
@@ -293,6 +297,9 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {
293297 continue
294298 }
295299
300+ // update session to persist payment balances
301+ updateSessionForAI (sess , sel .cap , sel .modelID , sel .node .Balances , oldBalances , oldSenderSessions )
302+
296303 if modelConstraint .Warm {
297304 warmSessions = append (warmSessions , sess )
298305 } else {
@@ -308,6 +315,24 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {
308315 return nil
309316}
310317
318+ func (sel * AISessionSelector ) getBalances () (map [string ]Balance , map [string ]pm.Sender ) {
319+ balances := make (map [string ]Balance )
320+ senders := make (map [string ]pm.Sender )
321+ for _ , sess := range sel .warmPool .sessMap {
322+ balances [sess .Transcoder ()] = sess .Balance
323+ senders [sess .Transcoder ()] = sess .Sender
324+ }
325+
326+ for _ , sess := range sel .coldPool .sessMap {
327+ if _ , ok := balances [sess .Transcoder ()]; ! ok {
328+ balances [sess .Transcoder ()] = sess .Balance
329+ senders [sess .Transcoder ()] = sess .Sender
330+ }
331+ }
332+
333+ return balances , senders
334+ }
335+
311336func (sel * AISessionSelector ) getSessions (ctx context.Context ) ([]* BroadcastSession , error ) {
312337 // No warm constraints applied here because we don't want to filter out orchs based on warm criteria at discovery time
313338 // Instead, we want all orchs that support the model and then will filter for orchs that have a warm model separately
@@ -384,9 +409,19 @@ func (c *AISessionManager) Select(ctx context.Context, cap core.Capability, mode
384409 return nil , nil
385410 }
386411
387- if err := refreshSessionIfNeeded (ctx , sess .BroadcastSession ); err != nil {
412+ //send a temp session to be refreshed
413+ // updateSession in broadcast.go updates the orchestrator OS and ticket params.
414+ // it also updates the pm.Sender session using ticket params and the Balance using the auth token.
415+ // we want to persist these to new
416+ newSess := * sess .BroadcastSession
417+ newSess .PMSessionID = strconv .Itoa (int (cap )) + "_" + modelID + "_" + "temp"
418+ newSess .Sender .StartSessionByID (* pmTicketParams (newSess .OrchestratorInfo .TicketParams ), newSess .PMSessionID )
419+ if err := refreshSessionIfNeeded (ctx , & newSess ); err != nil {
388420 return nil , err
389421 }
422+ sess .BroadcastSession .OrchestratorInfo = newSess .OrchestratorInfo
423+ sess .Sender .UpdateSessionByID (* pmTicketParams (sess .OrchestratorInfo .TicketParams ), sess .PMSessionID )
424+ //updateSessionForAI(sess.BroadcastSession, cap, modelID, c.node.Balances)
390425
391426 return sess , nil
392427}
@@ -432,3 +467,25 @@ func (c *AISessionManager) getSelector(ctx context.Context, cap core.Capability,
432467
433468 return sel , nil
434469}
470+
471+ func updateSessionForAI (sess * BroadcastSession , cap core.Capability , modelID string , balances * core.AddressBalances , oldBalances map [string ]Balance , oldSenderSessions map [string ]pm.Sender ) {
472+ //clean up other session
473+ sess .CleanupSession (sess .PMSessionID )
474+ // override PMSessionID to track tickets per pipeline/model
475+ transcoderUrl := sess .Transcoder ()
476+ sess .lock .Lock ()
477+ defer sess .lock .Unlock ()
478+ sess .PMSessionID = strconv .Itoa (int (cap )) + "_" + modelID
479+ // save balance between refreshes
480+ if oldBalance , ok := oldBalances [transcoderUrl ]; ok {
481+ sess .Balance = oldBalance
482+ } else {
483+ sess .Balance = core .NewBalance (ethcommon .BytesToAddress (sess .OrchestratorInfo .TicketParams .Recipient ), core .ManifestID (strconv .Itoa (int (cap ))+ "_" + modelID ), balances )
484+ }
485+ // save sender sessions between refreshes
486+ if oldSenderSession , ok := oldSenderSessions [transcoderUrl ]; ok {
487+ sess .Sender = oldSenderSession
488+ } else {
489+ sess .Sender .StartSessionByID (* pmTicketParams (sess .OrchestratorInfo .TicketParams ), sess .PMSessionID )
490+ }
491+ }
0 commit comments