@@ -3,6 +3,7 @@ package proxy
33import (
44 "log/slog"
55 "math/rand"
6+ "net/http"
67 "sync"
78 "time"
89)
@@ -11,14 +12,17 @@ import (
1112// used as toleration when comparing floats.
1213const epsilon = 1e-9
1314
// ProxyKeyHeader is the name of the HTTP request header whose value is used
// as the sticky-session identifier.
const ProxyKeyHeader = "X-PROXY-KEY"
17+
1418// LoadBalancer is an interface for load balancing algorithms. It provides
1519// methods to update the list of available servers and to select the next
1620// server to be used.
1721type LoadBalancer interface {
1822 // Update updates the list of available servers.
1923 Update ([]* Server )
2024 // Next returns the next server to be used based on the load balancing algorithm.
21- Next () * Server
25+ Next (* http. Request ) * Server
2226}
2327
2428// RoundRobin is a simple load balancer that distributes incoming requests
@@ -43,7 +47,7 @@ func NewRoundRobin(log *slog.Logger) *RoundRobin {
4347
4448// Next returns the next server to be used based on the round-robin algorithm.
4549// If the selected server is unhealthy, it will recursively try the next server.
46- func (rr * RoundRobin ) Next () * Server {
50+ func (rr * RoundRobin ) Next (_ * http. Request ) * Server {
4751 rr .mu .Lock ()
4852 if len (rr .servers ) == 0 {
4953 return nil
@@ -56,7 +60,7 @@ func (rr *RoundRobin) Next() *Server {
5660 return server
5761 }
5862 rr .log .Warn ("server is unhealthy, trying next" , "name" , server .name )
59- return rr .Next ()
63+ return rr .Next (nil )
6064}
6165
6266// Update updates the list of available servers.
@@ -104,7 +108,7 @@ func NewLatencyBased(log *slog.Logger) *LatencyBased {
104108// The random number will fall into one of these ranges, effectively selecting
105109// a server based on its latency rate. This approach works regardless of the order of
106110// the servers, so there's no need to sort them based on latency or rate.
107- func (rr * LatencyBased ) Next () * Server {
111+ func (rr * LatencyBased ) Next (_ * http. Request ) * Server {
108112 rr .mu .Lock ()
109113 defer rr .mu .Unlock ()
110114
@@ -159,3 +163,149 @@ func (rr *LatencyBased) Update(servers []*Server) {
159163 rr .servers [i ].Rate /= totalInverse
160164 }
161165}
166+
167+ // StickyLatencyBased is a load balancer that combines session affinity with latency-based routing.
168+ // It embeds LatencyBased to reuse latency calculation and server management functionality,
169+ // while adding session stickiness using industry-standard headers and cookies.
170+ // Warning: This load balancer type is not effective if running alongside other replicas as
171+ // the state is not shared between replicas.
172+ type StickyLatencyBased struct {
173+ // LatencyBased provides the core latency-based selection functionality
174+ * LatencyBased
175+ // sessionMap maps session identifiers to server references for sticky sessions.
176+ sessionMap map [string ]* Server
177+ // sessionMu is a separate mutex for session-specific operations to avoid lock contention
178+ sessionMu sync.RWMutex
179+ // sessionTimeout defines how long sessions are kept in memory.
180+ sessionTimeout time.Duration
181+ // sessionCleanupTicker periodically cleans up expired sessions.
182+ sessionCleanupTicker * time.Ticker
183+ // sessionTimestamps tracks when sessions were last accessed.
184+ sessionTimestamps map [string ]time.Time
185+ }
186+
187+ // NewStickyLatencyBased returns a new StickyLatencyBased load balancer instance.
188+ // It embeds a LatencyBased load balancer and adds session management functionality.
189+ func NewStickyLatencyBased (log * slog.Logger , sessionTimeout time.Duration ) * StickyLatencyBased {
190+ if sessionTimeout == 0 {
191+ sessionTimeout = 30 * time .Minute // Default session timeout
192+ }
193+
194+ slb := & StickyLatencyBased {
195+ LatencyBased : NewLatencyBased (log ),
196+ sessionMap : make (map [string ]* Server ),
197+ sessionTimestamps : make (map [string ]time.Time ),
198+ sessionTimeout : sessionTimeout ,
199+ }
200+
201+ // Start cleanup routine for expired sessions
202+ slb .sessionCleanupTicker = time .NewTicker (5 * time .Minute )
203+ go slb .cleanupExpiredSessions ()
204+
205+ return slb
206+ }
207+
208+ // Next returns the next server based on session affinity and latency.
209+ // It first checks for existing session identifiers in headers or cookies,
210+ // then falls back to the embedded LatencyBased selection for new sessions.
211+ func (slb * StickyLatencyBased ) Next (req * http.Request ) * Server {
212+ if req == nil {
213+ slb .log .Warn ("provided request is nil" )
214+ return slb .LatencyBased .Next (nil )
215+ }
216+
217+ slb .LatencyBased .mu .Lock ()
218+ if len (slb .LatencyBased .servers ) == 0 {
219+ slb .LatencyBased .mu .Unlock ()
220+ return nil
221+ }
222+ slb .LatencyBased .mu .Unlock ()
223+
224+ sessionID := slb .extractSessionID (req )
225+
226+ if sessionID != "" {
227+ slb .sessionMu .RLock ()
228+ if server , exists := slb .sessionMap [sessionID ]; exists {
229+ // Check if session has timed out (cache miss scenario)
230+ lastAccessed := slb .sessionTimestamps [sessionID ]
231+ if time .Since (lastAccessed ) > slb .sessionTimeout {
232+ slb .sessionMu .RUnlock ()
233+
234+ // Session timed out, clean it up
235+ slb .sessionMu .Lock ()
236+ delete (slb .sessionMap , sessionID )
237+ delete (slb .sessionTimestamps , sessionID )
238+ slb .sessionMu .Unlock ()
239+
240+ slb .log .Info ("session timed out, removed" ,
241+ "session_id" , sessionID ,
242+ "server" , server .name ,
243+ "last_accessed" , lastAccessed )
244+ } else {
245+ // Session is valid (cache hit scenario)
246+ slb .sessionMu .RUnlock ()
247+
248+ // Update session timestamp
249+ slb .sessionMu .Lock ()
250+ slb .sessionTimestamps [sessionID ] = time .Now ()
251+ slb .sessionMu .Unlock ()
252+ return server
253+ }
254+ } else {
255+ slb .sessionMu .RUnlock ()
256+ }
257+ }
258+
259+ // No existing session or unhealthy server, use embedded LatencyBased selection
260+ server := slb .LatencyBased .Next (req )
261+
262+ if server != nil && sessionID != "" {
263+ // Create new session mapping
264+ slb .sessionMu .Lock ()
265+ slb .sessionMap [sessionID ] = server
266+ slb .sessionTimestamps [sessionID ] = time .Now ()
267+ slb .sessionMu .Unlock ()
268+
269+ slb .log .Debug ("created new sticky session" ,
270+ "session_id" , sessionID ,
271+ "server" , server .name )
272+ }
273+
274+ return server
275+ }
276+
277+ // extractSessionID extracts session identifier from HTTP request.
278+ // It only checks for the X-PROXY-KEY header. If not provided, returns empty string
279+ // which will cause the load balancer to use normal latency-based selection.
280+ func (slb * StickyLatencyBased ) extractSessionID (req * http.Request ) string {
281+ return req .Header .Get (ProxyKeyHeader )
282+ }
283+
284+ // Update updates the list of available servers using the embedded LatencyBased functionality
285+ // and cleans up session mappings for servers that no longer exist.
286+ func (slb * StickyLatencyBased ) Update (servers []* Server ) {
287+ slb .LatencyBased .Update (servers )
288+ }
289+
290+ // cleanupExpiredSessions runs in a background goroutine to clean up expired sessions
291+ func (slb * StickyLatencyBased ) cleanupExpiredSessions () {
292+ for range slb .sessionCleanupTicker .C {
293+ slb .sessionMu .Lock ()
294+ now := time .Now ()
295+ for sessionID , timestamp := range slb .sessionTimestamps {
296+ if now .Sub (timestamp ) > slb .sessionTimeout {
297+ delete (slb .sessionMap , sessionID )
298+ delete (slb .sessionTimestamps , sessionID )
299+ slb .log .Debug ("cleaned up expired session" , "session_id" , sessionID )
300+ }
301+ }
302+ slb .sessionMu .Unlock ()
303+ }
304+ }
305+
306+ // Stop stops the cleanup ticker and releases resources
307+ func (slb * StickyLatencyBased ) Stop () {
308+ if slb .sessionCleanupTicker != nil {
309+ slb .sessionCleanupTicker .Stop ()
310+ }
311+ }
0 commit comments