Merged
8 changes: 4 additions & 4 deletions cmd/main.go
@@ -165,9 +165,9 @@ func runProxy(cfg config.Config) {
		},
	}
	seeder := seed.New(seederCfg, log, rpcListener, restListener, grpcListener)
-	rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewLatencyBased(log))
-	restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewLatencyBased(log))
-	grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewLatencyBased(log))
+	rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, 6*time.Second))
+	restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, 6*time.Second))
+	grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewStickyLatencyBased(log, 6*time.Second))

	ctx, proxyCtxCancel := context.WithCancel(context.Background())
	defer proxyCtxCancel()
@@ -264,7 +264,7 @@ func runProxy(cfg config.Config) {
}

func main() {
-	var v = viper.New()
+	v := viper.New()

	if err := NewRootCmd(v).Execute(); err != nil {
		log.Fatalf("failed to execute command: %v", err)
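Since each proxy handler above gets its own StickyLatencyBased instance, three cleanup tickers and their goroutines now run for the life of the process. A minimal sketch of how a shutdown path could release one of them, assuming the runProxy variables from this diff; the rpcBalancer variable and the deferred Stop are illustrative additions, not part of the PR:

	// Sketch only: hold the balancer in a variable so Stop can be deferred.
	rpcBalancer := proxy.NewStickyLatencyBased(log, 6*time.Second)
	defer rpcBalancer.Stop() // stop the session-cleanup ticker and its goroutine on shutdown
	rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, rpcBalancer)
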
158 changes: 154 additions & 4 deletions internal/proxy/balancer.go
@@ -3,6 +3,7 @@ package proxy
import (
"log/slog"
"math/rand"
"net/http"
"sync"
"time"
)
@@ -11,14 +12,17 @@ import (
// used as tolerance when comparing floats.
const epsilon = 1e-9

+// ProxyKeyHeader is the HTTP header used for sticky session identification.
+const ProxyKeyHeader = "X-PROXY-KEY"
+
// LoadBalancer is an interface for load balancing algorithms. It provides
// methods to update the list of available servers and to select the next
// server to be used.
type LoadBalancer interface {
	// Update updates the list of available servers.
	Update([]*Server)
	// Next returns the next server to be used based on the load balancing algorithm.
-	Next() *Server
+	Next(*http.Request) *Server
}

// RoundRobin is a simple load balancer that distributes incoming requests
@@ -43,7 +47,7 @@ func NewRoundRobin(log *slog.Logger) *RoundRobin {

// Next returns the next server to be used based on the round-robin algorithm.
// If the selected server is unhealthy, it will recursively try the next server.
-func (rr *RoundRobin) Next() *Server {
+func (rr *RoundRobin) Next(_ *http.Request) *Server {
	rr.mu.Lock()
	if len(rr.servers) == 0 {
		return nil
@@ -56,7 +60,7 @@ func (rr *RoundRobin) Next() *Server {
		return server
	}
	rr.log.Warn("server is unhealthy, trying next", "name", server.name)
-	return rr.Next()
+	return rr.Next(nil)
}

// Update updates the list of available servers.
@@ -104,7 +108,7 @@ func NewLatencyBased(log *slog.Logger) *LatencyBased {
// The random number will fall into one of these ranges, effectively selecting
// a server based on its latency rate. This approach works regardless of the order of
// the servers, so there's no need to sort them based on latency or rate.
-func (rr *LatencyBased) Next() *Server {
+func (rr *LatencyBased) Next(_ *http.Request) *Server {
	rr.mu.Lock()
	defer rr.mu.Unlock()

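To make the rate computation in Update below concrete: with two backends at mean latencies of 10ms and 30ms, the inverse weights are 1/0.010 = 100 and 1/0.030 ≈ 33.3, and dividing each by their sum ≈ 133.3 yields rates of 0.75 and 0.25, so the random draw selects the faster server three times as often. A self-contained sketch of that arithmetic (illustrative numbers, not code from this PR):

package main

import "fmt"

func main() {
	latencies := []float64{0.010, 0.030} // mean latency per server, in seconds (illustrative)
	inverses := make([]float64, len(latencies))
	var total float64
	for i, l := range latencies {
		inverses[i] = 1 / l // faster servers get larger weights
		total += inverses[i]
	}
	for i, inv := range inverses {
		fmt.Printf("server %d rate: %.2f\n", i, inv/total) // prints 0.75 and 0.25
	}
}
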
@@ -159,3 +163,149 @@ func (rr *LatencyBased) Update(servers []*Server) {
		rr.servers[i].Rate /= totalInverse
	}
}
+
+// StickyLatencyBased is a load balancer that combines session affinity with latency-based routing.
+// It embeds LatencyBased to reuse the latency calculation and server management functionality,
+// while adding session stickiness keyed on the X-PROXY-KEY request header.
+// Warning: this load balancer is not effective when running alongside other replicas, as
+// the session state is not shared between replicas.
+type StickyLatencyBased struct {
+	// LatencyBased provides the core latency-based selection functionality.
+	*LatencyBased
+	// sessionMap maps session identifiers to server references for sticky sessions.
+	sessionMap map[string]*Server
+	// sessionMu is a separate mutex for session operations, avoiding contention
+	// with the embedded balancer's lock.
+	sessionMu sync.RWMutex
+	// sessionTimeout defines how long idle sessions are kept in memory.
+	sessionTimeout time.Duration
+	// sessionCleanupTicker periodically triggers cleanup of expired sessions.
+	sessionCleanupTicker *time.Ticker
+	// sessionTimestamps tracks when each session was last accessed.
+	sessionTimestamps map[string]time.Time
+	// done signals the cleanup goroutine to exit when Stop is called.
+	done chan struct{}
+}
+
+// NewStickyLatencyBased returns a new StickyLatencyBased load balancer instance.
+// It embeds a LatencyBased load balancer and adds session management functionality.
+func NewStickyLatencyBased(log *slog.Logger, sessionTimeout time.Duration) *StickyLatencyBased {
+	if sessionTimeout == 0 {
+		sessionTimeout = 30 * time.Minute // Default session timeout.
+	}
+
+	slb := &StickyLatencyBased{
+		LatencyBased:      NewLatencyBased(log),
+		sessionMap:        make(map[string]*Server),
+		sessionTimestamps: make(map[string]time.Time),
+		sessionTimeout:    sessionTimeout,
+		done:              make(chan struct{}),
+	}
+
+	// Start the cleanup routine for expired sessions.
+	slb.sessionCleanupTicker = time.NewTicker(5 * time.Minute)
+	go slb.cleanupExpiredSessions()
+
+	return slb
+}
+
+// Next returns the next server based on session affinity and latency.
+// It first checks for an existing session identifier in the request,
+// then falls back to the embedded LatencyBased selection for new sessions.
+func (slb *StickyLatencyBased) Next(req *http.Request) *Server {
+	if req == nil {
+		slb.log.Warn("provided request is nil")
+		return slb.LatencyBased.Next(nil)
+	}
+
+	slb.LatencyBased.mu.Lock()
+	if len(slb.LatencyBased.servers) == 0 {
+		slb.LatencyBased.mu.Unlock()
+		return nil
+	}
+	slb.LatencyBased.mu.Unlock()
+
+	sessionID := slb.extractSessionID(req)
+
+	if sessionID != "" {
+		slb.sessionMu.RLock()
+		if server, exists := slb.sessionMap[sessionID]; exists {
+			lastAccessed := slb.sessionTimestamps[sessionID]
+			if time.Since(lastAccessed) > slb.sessionTimeout {
+				slb.sessionMu.RUnlock()
+
+				// The session has timed out: drop it and fall through to
+				// latency-based selection below.
+				slb.sessionMu.Lock()
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.sessionMu.Unlock()
+
+				slb.log.Info("session timed out, removed",
+					"session_id", sessionID,
+					"server", server.name,
+					"last_accessed", lastAccessed)
+			} else {
+				// The session is still valid: refresh its timestamp and
+				// keep routing to the same server.
+				slb.sessionMu.RUnlock()
+
+				slb.sessionMu.Lock()
+				slb.sessionTimestamps[sessionID] = time.Now()
+				slb.sessionMu.Unlock()
+				return server
+			}
+		} else {
+			slb.sessionMu.RUnlock()
+		}
+	}
+
+	// No session, or the session expired: use the embedded LatencyBased selection.
+	server := slb.LatencyBased.Next(req)
+
+	if server != nil && sessionID != "" {
+		// Create a new session mapping.
+		slb.sessionMu.Lock()
+		slb.sessionMap[sessionID] = server
+		slb.sessionTimestamps[sessionID] = time.Now()
+		slb.sessionMu.Unlock()
+
+		slb.log.Debug("created new sticky session",
+			"session_id", sessionID,
+			"server", server.name)
+	}
+
+	return server
+}
+
+// extractSessionID extracts the session identifier from the HTTP request.
+// It only checks the X-PROXY-KEY header. If the header is absent it returns
+// an empty string, which makes the load balancer fall back to plain
+// latency-based selection.
+func (slb *StickyLatencyBased) extractSessionID(req *http.Request) string {
+	return req.Header.Get(ProxyKeyHeader)
+}
+
+// Update updates the list of available servers using the embedded LatencyBased
+// functionality. Session mappings that point at servers which no longer exist
+// are not pruned here; they are evicted once their session times out.
+func (slb *StickyLatencyBased) Update(servers []*Server) {
+	slb.LatencyBased.Update(servers)
+}
+
+// cleanupExpiredSessions runs in a background goroutine and evicts expired
+// sessions until Stop is called.
+func (slb *StickyLatencyBased) cleanupExpiredSessions() {
+	for {
+		select {
+		case <-slb.done:
+			return
+		case <-slb.sessionCleanupTicker.C:
+			slb.sessionMu.Lock()
+			now := time.Now()
+			for sessionID, timestamp := range slb.sessionTimestamps {
+				if now.Sub(timestamp) > slb.sessionTimeout {
+					delete(slb.sessionMap, sessionID)
+					delete(slb.sessionTimestamps, sessionID)
+					slb.log.Debug("cleaned up expired session", "session_id", sessionID)
+				}
+			}
+			slb.sessionMu.Unlock()
+		}
+	}
+}
+
+// Stop stops the cleanup ticker and signals the cleanup goroutine to exit.
+// Stopping the ticker alone is not enough: Ticker.Stop does not close the
+// ticker channel, so a plain range over it would block forever.
+func (slb *StickyLatencyBased) Stop() {
+	if slb.sessionCleanupTicker != nil {
+		slb.sessionCleanupTicker.Stop()
+	}
+	close(slb.done)
+}
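
From a client's point of view, stickiness is opt-in: requests that send the same X-PROXY-KEY value keep hitting the same backend until the session idles past the timeout (6 seconds as wired in cmd/main.go above). A hedged sketch of a client pinning its session; the proxy address and path are placeholders, not taken from this PR:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Placeholder address: point this at the proxy's REST listener.
	req, err := http.NewRequest(http.MethodGet, "http://localhost:8080/status", nil)
	if err != nil {
		panic(err)
	}
	// Every request carrying this key is routed to the same backend,
	// as long as the requests arrive within the session timeout.
	req.Header.Set("X-PROXY-KEY", "session-42")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}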