@@ -21,6 +21,7 @@ import { oaCompatHelper } from "./provider/openai-compatible"
2121import { createRateLimiter } from "./rateLimiter"
2222import { createDataDumper } from "./dataDumper"
2323import { createTrialLimiter } from "./trialLimiter"
24+ import { createStickyTracker } from "./stickyProviderTracker"
2425
2526type ZenData = Awaited < ReturnType < typeof ZenData . list > >
2627type RetryOptions = {
@@ -68,9 +69,11 @@ export async function handler(
6869 const isTrial = await trialLimiter ?. isTrial ( )
6970 const rateLimiter = createRateLimiter ( modelInfo . id , modelInfo . rateLimit , ip )
7071 await rateLimiter ?. check ( )
72+ const stickyTracker = createStickyTracker ( modelInfo . stickyProvider ?? false , sessionId )
73+ const stickyProvider = await stickyTracker ?. get ( )
7174
7275 const retriableRequest = async ( retry : RetryOptions = { excludeProviders : [ ] , retryCount : 0 } ) => {
73- const providerInfo = selectProvider ( zenData , modelInfo , sessionId , isTrial ?? false , retry )
76+ const providerInfo = selectProvider ( zenData , modelInfo , sessionId , isTrial ?? false , retry , stickyProvider )
7477 const authInfo = await authenticate ( modelInfo , providerInfo )
7578 validateBilling ( authInfo , modelInfo )
7679 validateModelSettings ( authInfo )
@@ -121,6 +124,9 @@ export async function handler(
121124 dataDumper ?. provideModel ( providerInfo . storeModel )
122125 dataDumper ?. provideRequest ( reqBody )
123126
127+ // Store sticky provider
128+ await stickyTracker ?. set ( providerInfo . id )
129+
124130 // Scrub response headers
125131 const resHeaders = new Headers ( )
126132 const keepHeaders = [ "content-type" , "cache-control" ]
@@ -289,12 +295,18 @@ export async function handler(
289295 sessionId : string ,
290296 isTrial : boolean ,
291297 retry : RetryOptions ,
298+ stickyProvider : string | undefined ,
292299 ) {
293300 const provider = ( ( ) => {
294301 if ( isTrial ) {
295302 return modelInfo . providers . find ( ( provider ) => provider . id === modelInfo . trial ! . provider )
296303 }
297304
305+ if ( stickyProvider ) {
306+ const provider = modelInfo . providers . find ( ( provider ) => provider . id === stickyProvider )
307+ if ( provider ) return provider
308+ }
309+
298310 if ( retry . retryCount === MAX_RETRIES ) {
299311 return modelInfo . providers . find ( ( provider ) => provider . id === modelInfo . fallbackProvider )
300312 }
0 commit comments