Skip to content

Commit d958d92

Browse files
committed
feat: add a sticky session latency-based load balancer
1 parent a14fada commit d958d92

File tree

7 files changed

+539
-14
lines changed

7 files changed

+539
-14
lines changed

cmd/main.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ func runProxy(cfg config.Config) {
165165
},
166166
}
167167
seeder := seed.New(seederCfg, log, rpcListener, restListener, grpcListener)
168-
rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewLatencyBased(log))
169-
restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewLatencyBased(log))
170-
grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewLatencyBased(log))
168+
rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, 6*time.Second))
169+
restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, 6*time.Second))
170+
grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewStickyLatencyBased(log, 6*time.Second))
171171

172172
ctx, proxyCtxCancel := context.WithCancel(context.Background())
173173
defer proxyCtxCancel()
@@ -264,7 +264,7 @@ func runProxy(cfg config.Config) {
264264
}
265265

266266
func main() {
267-
var v = viper.New()
267+
v := viper.New()
268268

269269
if err := NewRootCmd(v).Execute(); err != nil {
270270
log.Fatalf("failed to execute command: %v", err)

internal/proxy/balancer.go

Lines changed: 154 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package proxy
33
import (
44
"log/slog"
55
"math/rand"
6+
"net/http"
67
"sync"
78
"time"
89
)
@@ -11,14 +12,17 @@ import (
1112
// used as toleration when comparing floats.
1213
const epsilon = 1e-9
1314

15+
// ProxyKeyHeader is the HTTP header used for sticky session identification
16+
const ProxyKeyHeader = "X-PROXY-KEY"
17+
1418
// LoadBalancer is an interface for load balancing algorithms. It provides
1519
// methods to update the list of available servers and to select the next
1620
// server to be used.
1721
type LoadBalancer interface {
1822
// Update updates the list of available servers.
1923
Update([]*Server)
2024
// Next returns the next server to be used based on the load balancing algorithm.
21-
Next() *Server
25+
Next(*http.Request) *Server
2226
}
2327

2428
// RoundRobin is a simple load balancer that distributes incoming requests
@@ -43,7 +47,7 @@ func NewRoundRobin(log *slog.Logger) *RoundRobin {
4347

4448
// Next returns the next server to be used based on the round-robin algorithm.
4549
// If the selected server is unhealthy, it will recursively try the next server.
46-
func (rr *RoundRobin) Next() *Server {
50+
func (rr *RoundRobin) Next(_ *http.Request) *Server {
4751
rr.mu.Lock()
4852
if len(rr.servers) == 0 {
4953
return nil
@@ -56,7 +60,7 @@ func (rr *RoundRobin) Next() *Server {
5660
return server
5761
}
5862
rr.log.Warn("server is unhealthy, trying next", "name", server.name)
59-
return rr.Next()
63+
return rr.Next(nil)
6064
}
6165

6266
// Update updates the list of available servers.
@@ -104,7 +108,7 @@ func NewLatencyBased(log *slog.Logger) *LatencyBased {
104108
// The random number will fall into one of these ranges, effectively selecting
105109
// a server based on its latency rate. This approach works regardless of the order of
106110
// the servers, so there's no need to sort them based on latency or rate.
107-
func (rr *LatencyBased) Next() *Server {
111+
func (rr *LatencyBased) Next(_ *http.Request) *Server {
108112
rr.mu.Lock()
109113
defer rr.mu.Unlock()
110114

@@ -159,3 +163,149 @@ func (rr *LatencyBased) Update(servers []*Server) {
159163
rr.servers[i].Rate /= totalInverse
160164
}
161165
}
166+
167+
// StickyLatencyBased is a load balancer that combines session affinity with latency-based routing.
// It embeds LatencyBased to reuse latency calculation and server management functionality,
// while adding session stickiness keyed on the X-PROXY-KEY request header
// (see extractSessionID; cookies are NOT consulted).
// Warning: This load balancer type is not effective if running alongside other replicas as
// the state is not shared between replicas.
type StickyLatencyBased struct {
	// LatencyBased provides the core latency-based selection functionality
	// (server list, rates, logger). Its own mutex guards the server list.
	*LatencyBased
	// sessionMap maps session identifiers to server references for sticky sessions.
	// Guarded by sessionMu.
	sessionMap map[string]*Server
	// sessionMu is a separate mutex for session-specific operations to avoid lock contention
	// with the embedded balancer's mutex. It protects sessionMap and sessionTimestamps.
	sessionMu sync.RWMutex
	// sessionTimeout defines how long an idle session mapping is kept in memory.
	sessionTimeout time.Duration
	// sessionCleanupTicker drives the periodic expired-session sweep
	// (see cleanupExpiredSessions).
	sessionCleanupTicker *time.Ticker
	// sessionTimestamps tracks when sessions were last accessed; compared
	// against sessionTimeout to decide expiry. Guarded by sessionMu.
	sessionTimestamps map[string]time.Time
}
186+
187+
// NewStickyLatencyBased returns a new StickyLatencyBased load balancer instance.
188+
// It embeds a LatencyBased load balancer and adds session management functionality.
189+
func NewStickyLatencyBased(log *slog.Logger, sessionTimeout time.Duration) *StickyLatencyBased {
190+
if sessionTimeout == 0 {
191+
sessionTimeout = 30 * time.Minute // Default session timeout
192+
}
193+
194+
slb := &StickyLatencyBased{
195+
LatencyBased: NewLatencyBased(log),
196+
sessionMap: make(map[string]*Server),
197+
sessionTimestamps: make(map[string]time.Time),
198+
sessionTimeout: sessionTimeout,
199+
}
200+
201+
// Start cleanup routine for expired sessions
202+
slb.sessionCleanupTicker = time.NewTicker(5 * time.Minute)
203+
go slb.cleanupExpiredSessions()
204+
205+
return slb
206+
}
207+
208+
// Next returns the next server based on session affinity and latency.
// It first looks for an existing session (identified by the X-PROXY-KEY
// header), then falls back to the embedded LatencyBased selection when the
// request carries no session, the session is unknown, or it has expired.
//
// NOTE(review): a sticky hit returns the mapped server without re-checking
// its health or whether it is still in the current server pool — only the
// fallback path benefits from LatencyBased's selection logic. Confirm this
// is acceptable, or add a health check on the hit path.
func (slb *StickyLatencyBased) Next(req *http.Request) *Server {
	// A nil request cannot carry a session key; degrade to pure latency-based
	// selection rather than failing.
	if req == nil {
		slb.log.Warn("provided request is nil")
		return slb.LatencyBased.Next(nil)
	}

	// Early-out when the pool is empty. This check is advisory only: the
	// embedded lock is released before selection, so the pool may change
	// in between (benign — LatencyBased.Next re-checks under its own lock).
	slb.LatencyBased.mu.Lock()
	if len(slb.LatencyBased.servers) == 0 {
		slb.LatencyBased.mu.Unlock()
		return nil
	}
	slb.LatencyBased.mu.Unlock()

	sessionID := slb.extractSessionID(req)

	if sessionID != "" {
		slb.sessionMu.RLock()
		if server, exists := slb.sessionMap[sessionID]; exists {
			// Check if session has timed out (cache miss scenario).
			lastAccessed := slb.sessionTimestamps[sessionID]
			if time.Since(lastAccessed) > slb.sessionTimeout {
				// Upgrade read lock to write lock to delete the entry.
				// NOTE(review): the RUnlock→Lock gap means a concurrent
				// request can mutate the maps in between; the deletes below
				// are idempotent, so the race is benign.
				slb.sessionMu.RUnlock()

				// Session timed out, clean it up.
				slb.sessionMu.Lock()
				delete(slb.sessionMap, sessionID)
				delete(slb.sessionTimestamps, sessionID)
				slb.sessionMu.Unlock()

				slb.log.Info("session timed out, removed",
					"session_id", sessionID,
					"server", server.name,
					"last_accessed", lastAccessed)
			} else {
				// Session is valid (cache hit scenario).
				slb.sessionMu.RUnlock()

				// Refresh the last-access timestamp to keep the session alive.
				slb.sessionMu.Lock()
				slb.sessionTimestamps[sessionID] = time.Now()
				slb.sessionMu.Unlock()
				return server
			}
		} else {
			slb.sessionMu.RUnlock()
		}
	}

	// No existing (valid) session: use embedded LatencyBased selection.
	server := slb.LatencyBased.Next(req)

	if server != nil && sessionID != "" {
		// Create new session mapping so subsequent requests with the same
		// key stick to this server.
		slb.sessionMu.Lock()
		slb.sessionMap[sessionID] = server
		slb.sessionTimestamps[sessionID] = time.Now()
		slb.sessionMu.Unlock()

		slb.log.Debug("created new sticky session",
			"session_id", sessionID,
			"server", server.name)
	}

	return server
}
276+
277+
// extractSessionID extracts session identifier from HTTP request.
278+
// It only checks for the X-PROXY-KEY header. If not provided, returns empty string
279+
// which will cause the load balancer to use normal latency-based selection.
280+
func (slb *StickyLatencyBased) extractSessionID(req *http.Request) string {
281+
return req.Header.Get(ProxyKeyHeader)
282+
}
283+
284+
// Update updates the list of available servers by delegating to the embedded
// LatencyBased balancer.
//
// NOTE(review): despite what one might expect, session mappings are NOT
// pruned here — entries pointing at servers removed from the pool linger
// until the session times out or the periodic cleanup sweeps them, and a
// sticky hit in Next may keep routing to such a server in the meantime.
func (slb *StickyLatencyBased) Update(servers []*Server) {
	slb.LatencyBased.Update(servers)
}
289+
290+
// cleanupExpiredSessions runs in a background goroutine and, on every tick
// of sessionCleanupTicker, removes sessions whose last access is older than
// sessionTimeout.
//
// NOTE(review): time.Ticker.Stop (called by Stop) does not close the ticker's
// channel, so this range loop never terminates — after Stop the goroutine
// blocks forever on the channel receive. A separate done channel would be
// needed to shut it down cleanly.
func (slb *StickyLatencyBased) cleanupExpiredSessions() {
	for range slb.sessionCleanupTicker.C {
		slb.sessionMu.Lock()
		// Snapshot "now" once per sweep so all entries are judged against
		// the same instant.
		now := time.Now()
		for sessionID, timestamp := range slb.sessionTimestamps {
			if now.Sub(timestamp) > slb.sessionTimeout {
				delete(slb.sessionMap, sessionID)
				delete(slb.sessionTimestamps, sessionID)
				slb.log.Debug("cleaned up expired session", "session_id", sessionID)
			}
		}
		slb.sessionMu.Unlock()
	}
}
305+
306+
// Stop halts the periodic session-cleanup ticker. Safe to call when the
// ticker was never created.
//
// NOTE(review): stopping the ticker does not end the cleanupExpiredSessions
// goroutine — it remains blocked on the ticker channel for the lifetime of
// the process (see that method).
func (slb *StickyLatencyBased) Stop() {
	if slb.sessionCleanupTicker != nil {
		slb.sessionCleanupTicker.Stop()
	}
}

0 commit comments

Comments
 (0)