diff --git a/config_resolution.go b/config_resolution.go index 200c376..4a6c4b4 100644 --- a/config_resolution.go +++ b/config_resolution.go @@ -28,6 +28,7 @@ type configCLIInputs struct { MemoryRebalance bool MaxWorkers int MinWorkers int + WorkerQueueTimeout string ACMEDomain string ACMEEmail string ACMECacheDir string @@ -35,7 +36,8 @@ type configCLIInputs struct { } type resolvedConfig struct { - Server server.Config + Server server.Config + WorkerQueueTimeout time.Duration } func defaultServerConfig() server.Config { @@ -69,6 +71,7 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun } cfg := defaultServerConfig() + var workerQueueTimeout time.Duration if fileCfg != nil { if fileCfg.Host != "" { @@ -210,6 +213,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun if fileCfg.MinWorkers != 0 { cfg.MinWorkers = fileCfg.MinWorkers } + if fileCfg.WorkerQueueTimeout != "" { + if d, err := time.ParseDuration(fileCfg.WorkerQueueTimeout); err == nil { + workerQueueTimeout = d + } else { + warn("Invalid worker_queue_timeout duration: " + err.Error()) + } + } if len(fileCfg.PassthroughUsers) > 0 { cfg.PassthroughUsers = make(map[string]bool, len(fileCfg.PassthroughUsers)) for _, u := range fileCfg.PassthroughUsers { @@ -365,6 +375,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun warn("Invalid DUCKGRES_MAX_WORKERS: " + err.Error()) } } + if v := getenv("DUCKGRES_WORKER_QUEUE_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + workerQueueTimeout = d + } else { + warn("Invalid DUCKGRES_WORKER_QUEUE_TIMEOUT duration: " + err.Error()) + } + } if v := getenv("DUCKGRES_ACME_DOMAIN"); v != "" { cfg.ACMEDomain = v } @@ -456,6 +473,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun if cli.Set["max-workers"] { cfg.MaxWorkers = cli.MaxWorkers } + if cli.Set["worker-queue-timeout"] { + if d, err := time.ParseDuration(cli.WorkerQueueTimeout); err == nil { + workerQueueTimeout = d + } else { + warn("Invalid --worker-queue-timeout duration: " + err.Error()) + } + } if cli.Set["acme-domain"] { cfg.ACMEDomain = cli.ACMEDomain } @@ -482,6 +506,7 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun } return resolvedConfig{ - Server: cfg, + Server: cfg, + WorkerQueueTimeout: workerQueueTimeout, } } diff --git a/controlplane/control.go b/controlplane/control.go index cfd20f7..637dac0 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -28,6 +28,7 @@ type ControlPlaneConfig struct { ConfigPath string // Path to config file, passed to workers HandoverSocket string HealthCheckInterval time.Duration + WorkerQueueTimeout time.Duration // How long to wait for an available worker slot (default: 5m) } // ControlPlane manages the TCP listener and routes connections to Flight SQL workers. @@ -62,6 +63,9 @@ func RunControlPlane(cfg ControlPlaneConfig) { if cfg.HealthCheckInterval == 0 { cfg.HealthCheckInterval = 2 * time.Second } + if cfg.WorkerQueueTimeout == 0 { + cfg.WorkerQueueTimeout = 5 * time.Minute + } // Enforce secure defaults for control-plane mode. if err := validateControlPlaneSecurity(cfg); err != nil { @@ -153,7 +157,6 @@ func RunControlPlane(cfg ControlPlaneConfig) { // Lock ordering invariant: rebalancer.mu → sm.mu(RLock). Never acquire // rebalancer.mu while holding sm.mu to avoid deadlock. 
rebalancer.SetSessionLister(sessions) - pool.SetSessionCounter(sessions) cp := &ControlPlane{ cfg: cfg, @@ -171,10 +174,11 @@ func RunControlPlane(cfg ControlPlaneConfig) { // creation races between pre-warm and first external Flight requests. if cfg.FlightPort > 0 { flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, cp.rateLimiter, FlightIngressConfig{ - SessionIdleTTL: cfg.FlightSessionIdleTTL, - SessionReapTick: cfg.FlightSessionReapInterval, - HandleIdleTTL: cfg.FlightHandleIdleTTL, - SessionTokenTTL: cfg.FlightSessionTokenTTL, + SessionIdleTTL: cfg.FlightSessionIdleTTL, + SessionReapTick: cfg.FlightSessionReapInterval, + HandleIdleTTL: cfg.FlightHandleIdleTTL, + SessionTokenTTL: cfg.FlightSessionTokenTTL, + WorkerQueueTimeout: cfg.WorkerQueueTimeout, }) if err != nil { slog.Error("Failed to initialize Flight ingress.", "error", err) @@ -261,6 +265,7 @@ func RunControlPlane(cfg ControlPlaneConfig) { "flight_addr", cp.flightAddr(), "min_workers", minWorkers, "max_workers", maxWorkers, + "worker_queue_timeout", cfg.WorkerQueueTimeout, "memory_budget", formatBytes(rebalancer.memoryBudget), "memory_rebalance", cfg.MemoryRebalance) @@ -469,9 +474,11 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr) slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr) - // Create session on a worker - ctx := context.Background() + // Create session on a worker. The timeout controls how long we wait in the + // worker queue when all slots are occupied. + ctx, cancel := context.WithTimeout(context.Background(), cp.cfg.WorkerQueueTimeout) pid, executor, err := cp.sessions.CreateSession(ctx, username) + cancel() if err != nil { slog.Error("Failed to create session.", "user", username, "remote_addr", remoteAddr, "error", err) _ = server.WriteErrorResponse(writer, "FATAL", "53300", "too many connections") @@ -732,9 +739,11 @@ func (cp *ControlPlane) recoverFlightIngressAfterFailedReload() { } flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, cp.rateLimiter, FlightIngressConfig{ - SessionIdleTTL: cp.cfg.FlightSessionIdleTTL, - SessionReapTick: cp.cfg.FlightSessionReapInterval, - HandleIdleTTL: cp.cfg.FlightHandleIdleTTL, + SessionIdleTTL: cp.cfg.FlightSessionIdleTTL, + SessionReapTick: cp.cfg.FlightSessionReapInterval, + HandleIdleTTL: cp.cfg.FlightHandleIdleTTL, + SessionTokenTTL: cp.cfg.FlightSessionTokenTTL, + WorkerQueueTimeout: cp.cfg.WorkerQueueTimeout, }) if err != nil { slog.Error("Failed to recover Flight ingress after reload failure.", "error", err) diff --git a/controlplane/flight_ingress.go b/controlplane/flight_ingress.go index 74bd733..7a40a16 100644 --- a/controlplane/flight_ingress.go +++ b/controlplane/flight_ingress.go @@ -2,7 +2,6 @@ package controlplane import ( "crypto/tls" - "errors" "github.com/posthog/duckgres/server" "github.com/posthog/duckgres/server/flightsqlingress" @@ -15,14 +14,10 @@ type FlightIngress = flightsqlingress.FlightIngress // NewFlightIngress creates a control-plane Flight SQL ingress listener. 
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, rateLimiter *server.RateLimiter, cfg FlightIngressConfig) (*FlightIngress, error) { return flightsqlingress.NewFlightIngress(host, port, tlsConfig, users, sm, cfg, flightsqlingress.Options{ - IsMaxWorkersError: func(err error) bool { - return errors.Is(err, ErrMaxWorkersReached) - }, RateLimiter: rateLimiter, Hooks: flightsqlingress.Hooks{ OnSessionCountChanged: observeFlightAuthSessions, OnSessionsReaped: observeFlightSessionsReaped, - OnMaxWorkersRetry: observeFlightMaxWorkersRetry, }, }) } diff --git a/controlplane/flight_ingress_metrics.go b/controlplane/flight_ingress_metrics.go index 33cbbee..1608cc1 100644 --- a/controlplane/flight_ingress_metrics.go +++ b/controlplane/flight_ingress_metrics.go @@ -18,11 +18,6 @@ var flightSessionsReapedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Help: "Number of Flight auth sessions reaped", }, []string{"trigger"}) -var flightMaxWorkersRetryCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "duckgres_flight_max_workers_retry_total", - Help: "Number of max-worker retry outcomes when creating Flight auth sessions", -}, []string{"outcome"}) - func observeFlightAuthSessions(count int) { if count < 0 { count = 0 @@ -43,7 +38,3 @@ func observeFlightSessionsReaped(trigger string, count int) { } flightSessionsReapedCounter.WithLabelValues(trigger).Add(float64(count)) } - -func observeFlightMaxWorkersRetry(outcome string) { - flightMaxWorkersRetryCounter.WithLabelValues(outcome).Inc() -} diff --git a/controlplane/memory_rebalancer.go b/controlplane/memory_rebalancer.go index 97a402c..48cafd3 100644 --- a/controlplane/memory_rebalancer.go +++ b/controlplane/memory_rebalancer.go @@ -208,9 +208,9 @@ func memoryLimit(budget uint64) uint64 { } // DefaultMaxWorkers returns a reasonable default for max_workers. -// Defaults to number of CPUs * 2. +// Derived from the memory budget (budget / 256MB). func (r *MemoryRebalancer) DefaultMaxWorkers() int { - return runtime.NumCPU() * 2 + return int(r.memoryBudget / minMemoryPerSession) } // SetInitialLimits sets memory_limit and threads on a single session synchronously. diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go index 39af506..a6089a5 100644 --- a/controlplane/session_mgr.go +++ b/controlplane/session_mgr.go @@ -48,15 +48,16 @@ func NewSessionManager(pool *FlightWorkerPool, rebalancer *MemoryRebalancer) *Se // creates a session on it, and rebalances memory/thread limits across all active sessions. func (sm *SessionManager) CreateSession(ctx context.Context, username string) (int32, *server.FlightExecutor, error) { // Acquire a worker: reuses idle pre-warmed workers or spawns a new one. - // Max-workers check is atomic inside AcquireWorker to prevent TOCTOU races. - worker, err := sm.pool.AcquireWorker() + // When max-workers is set, this blocks until a slot is available. + worker, err := sm.pool.AcquireWorker(ctx) if err != nil { return 0, nil, fmt.Errorf("acquire worker: %w", err) } sessionToken, err := worker.CreateSession(ctx, username) if err != nil { - // Clean up the worker we just spawned (but not if it was a pre-warmed idle worker) + // Clean up the worker we just spawned (but not if it was a pre-warmed idle worker + // that has sessions from other concurrent requests). 
sm.pool.RetireWorkerIfNoSessions(worker.ID) return 0, nil, fmt.Errorf("create session on worker %d: %w", worker.ID, err) } diff --git a/controlplane/worker_mgr.go b/controlplane/worker_mgr.go index 0033219..796243d 100644 --- a/controlplane/worker_mgr.go +++ b/controlplane/worker_mgr.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" - "errors" "fmt" "log/slog" "os" @@ -22,39 +21,29 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -var ErrMaxWorkersReached = errors.New("max workers reached") - // ManagedWorker represents a duckdb-service worker process. type ManagedWorker struct { - ID int - cmd *exec.Cmd - socketPath string - bearerToken string - client *flightsql.Client - done chan struct{} // closed when process exits - exitErr error -} - -// SessionCounter provides session counts per worker for load balancing. -type SessionCounter interface { - SessionCountForWorker(workerID int) int + ID int + cmd *exec.Cmd + socketPath string + bearerToken string + client *flightsql.Client + done chan struct{} // closed when process exits + exitErr error + activeSessions int // Number of sessions currently assigned to this worker } -// FlightWorkerPool manages a pool of duckdb-service worker processes. -// -// Lock ordering invariant: pool.mu → SessionManager.mu(RLock). -// findIdleWorkerLocked calls SessionCountForWorker while holding pool.mu. -// Never acquire pool.mu while holding SessionManager.mu to avoid deadlock. type FlightWorkerPool struct { - mu sync.RWMutex - workers map[int]*ManagedWorker - nextWorkerID int // auto-incrementing worker ID - socketDir string - configPath string - binaryPath string - sessionCounter SessionCounter // set after SessionManager is created - maxWorkers int // 0 = unlimited - shuttingDown bool + mu sync.RWMutex + workers map[int]*ManagedWorker + nextWorkerID int // auto-incrementing worker ID + socketDir string + configPath string + binaryPath string + maxWorkers int // 0 = unlimited + shuttingDown bool + workerSem chan struct{} // buffered to maxWorkers; limits concurrent acquisitions + shutdownCh chan struct{} // closed by ShutdownAll to unblock queued waiters } // NewFlightWorkerPool creates a new worker pool. @@ -66,6 +55,10 @@ func NewFlightWorkerPool(socketDir, configPath string, maxWorkers int) *FlightWo configPath: configPath, binaryPath: binaryPath, maxWorkers: maxWorkers, + shutdownCh: make(chan struct{}), + } + if maxWorkers > 0 { + pool.workerSem = make(chan struct{}, maxWorkers) } observeControlPlaneWorkers(0) return pool @@ -199,12 +192,6 @@ func doHealthCheck(ctx context.Context, client *flightsql.Client) error { return nil } -// SetSessionCounter sets the session counter for load balancing. -// Must be called before accepting connections. -func (p *FlightWorkerPool) SetSessionCounter(sc SessionCounter) { - p.sessionCounter = sc -} - // Worker returns a worker by ID. func (p *FlightWorkerPool) Worker(id int) (*ManagedWorker, bool) { p.mu.RLock() @@ -235,33 +222,36 @@ func (p *FlightWorkerPool) SpawnMinWorkers(count int) error { return p.SpawnAll(count) } -// AcquireWorker returns a worker for a new session. It first tries to claim an -// idle pre-warmed worker (one with no active sessions). If none are available, -// it spawns a new one. The max-workers check is performed atomically under the -// write lock to prevent TOCTOU races from concurrent connections. -func (p *FlightWorkerPool) AcquireWorker() (*ManagedWorker, error) { +// AcquireWorker returns a worker for a new session. 
When maxWorkers is set, +// callers block in FIFO order on the semaphore until a slot is available, +// the context is cancelled, or the pool shuts down. +// Once a slot is acquired, it first tries to claim an idle pre-warmed worker +// (one with no active sessions). If none are available, it spawns a new one. +func (p *FlightWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, error) { + // Block until a semaphore slot is available (FIFO via Go's sudog queue). + if p.workerSem != nil { + select { + case p.workerSem <- struct{}{}: + // Got a slot + case <-ctx.Done(): + return nil, fmt.Errorf("timed out waiting for available worker (max_workers=%d): %w", p.maxWorkers, ctx.Err()) + case <-p.shutdownCh: + return nil, fmt.Errorf("pool is shutting down") + } + } + p.mu.Lock() if p.shuttingDown { p.mu.Unlock() + p.releaseWorkerSem() return nil, fmt.Errorf("pool is shutting down") } - // Check max-workers cap atomically under the write lock - if p.maxWorkers > 0 && len(p.workers) >= p.maxWorkers { - // Even at the cap, we may have idle pre-warmed workers to reuse. - // Only fail if all existing workers are busy. - idle := p.findIdleWorkerLocked() - if idle != nil { - p.mu.Unlock() - return idle, nil - } - p.mu.Unlock() - return nil, fmt.Errorf("%w (%d)", ErrMaxWorkersReached, p.maxWorkers) - } - - // Try to claim an idle pre-warmed worker before spawning a new one + // Try to claim an idle pre-warmed worker before spawning a new one. + // Atomic claim: increment activeSessions while holding the lock. idle := p.findIdleWorkerLocked() if idle != nil { + idle.activeSessions++ p.mu.Unlock() return idle, nil } @@ -271,16 +261,33 @@ func (p *FlightWorkerPool) AcquireWorker() (*ManagedWorker, error) { p.mu.Unlock() if err := p.SpawnWorker(id); err != nil { + p.releaseWorkerSem() return nil, err } w, ok := p.Worker(id) if !ok { + p.releaseWorkerSem() return nil, fmt.Errorf("worker %d not found after spawn", id) } + + p.mu.Lock() + w.activeSessions++ + p.mu.Unlock() + return w, nil } +// releaseWorkerSem drains one token from the semaphore (non-blocking). +func (p *FlightWorkerPool) releaseWorkerSem() { + if p.workerSem != nil { + select { + case <-p.workerSem: + default: + } + } +} + // findIdleWorkerLocked returns a live worker with no active sessions, or nil. // Caller must hold p.mu (read or write lock). func (p *FlightWorkerPool) findIdleWorkerLocked() *ManagedWorker { @@ -290,7 +297,7 @@ func (p *FlightWorkerPool) findIdleWorkerLocked() *ManagedWorker { continue // dead default: } - if p.sessionCounter != nil && p.sessionCounter.SessionCountForWorker(w.ID) == 0 { + if w.activeSessions == 0 { return w } } @@ -308,10 +315,16 @@ func (p *FlightWorkerPool) RetireWorker(id int) { return } delete(p.workers, id) + sessions := w.activeSessions workerCount := len(p.workers) p.mu.Unlock() observeControlPlaneWorkers(workerCount) + // Release semaphore slots so queued waiters can proceed. + for i := 0; i < sessions; i++ { + p.releaseWorkerSem() + } + // Run the actual process cleanup asynchronously so DestroySession // doesn't block the connection handler goroutine for up to 3s+. go retireWorkerProcess(w) @@ -319,11 +332,32 @@ func (p *FlightWorkerPool) RetireWorker(id int) { // RetireWorkerIfNoSessions retires a worker only if it has no active sessions. // Used to clean up on session creation failure without retiring pre-warmed workers. 
-func (p *FlightWorkerPool) RetireWorkerIfNoSessions(id int) { - if p.sessionCounter != nil && p.sessionCounter.SessionCountForWorker(id) > 0 { - return +// Returns true if the worker was retired (and its semaphore slot released). +func (p *FlightWorkerPool) RetireWorkerIfNoSessions(id int) bool { + p.mu.Lock() + w, ok := p.workers[id] + if !ok { + p.mu.Unlock() + return false + } + + // Decrement the acquisition claim we just made. + if w.activeSessions > 0 { + w.activeSessions-- + p.releaseWorkerSem() + } + + // If it has NO other active sessions, kill it to be safe (it might be broken). + if w.activeSessions == 0 { + delete(p.workers, id) + workerCount := len(p.workers) + p.mu.Unlock() + observeControlPlaneWorkers(workerCount) + go retireWorkerProcess(w) + return true } - p.RetireWorker(id) + p.mu.Unlock() + return false } // retireWorkerProcess handles the actual process shutdown and socket cleanup. @@ -384,6 +418,9 @@ func (p *FlightWorkerPool) ShutdownAll() { } p.mu.Unlock() + // Unblock all goroutines waiting in AcquireWorker's semaphore select. + close(p.shutdownCh) + for _, w := range workers { if w.cmd.Process != nil { slog.Info("Shutting down worker.", "id", w.ID, "pid", w.cmd.Process.Pid) @@ -483,6 +520,12 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du _ = w.client.Close() } _ = os.Remove(w.socketPath) + + // Release as many semaphore slots as this worker was using. + // In the 1:1 model, this is usually 1. + for i := 0; i < w.activeSessions; i++ { + p.releaseWorkerSem() + } default: // Worker is alive, do a health check. // Recover nil-pointer panics: w.client.Close() (from a @@ -535,6 +578,10 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du _ = w.client.Close() } _ = os.Remove(w.socketPath) + + for i := 0; i < w.activeSessions; i++ { + p.releaseWorkerSem() + } } } } else { diff --git a/controlplane/worker_mgr_test.go b/controlplane/worker_mgr_test.go index 56202c9..bbf2774 100644 --- a/controlplane/worker_mgr_test.go +++ b/controlplane/worker_mgr_test.go @@ -3,6 +3,7 @@ package controlplane import ( "context" "os/exec" + "sync" "testing" "time" ) @@ -230,3 +231,299 @@ func TestHealthCheckLoopDetectsCrashedWorker(t *testing.T) { t.Fatal("crashed worker should have been removed from pool") } } + +// makeFakeWorker creates a ManagedWorker with a started process that stays alive. +// Returns the worker and a cancel function to kill it. +func makeFakeWorker(t *testing.T, id int) (*ManagedWorker, func()) { + t.Helper() + cmd := exec.Command("sleep", "60") + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start fake worker process: %v", err) + } + done := make(chan struct{}) + w := &ManagedWorker{ + ID: id, + cmd: cmd, + done: done, + } + go func() { + w.exitErr = cmd.Wait() + close(done) + }() + cleanup := func() { + _ = cmd.Process.Kill() + <-done + } + return w, cleanup +} + +func TestAcquireWorkerBlocksUntilSlotAvailable(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 2) + + // Pre-populate 2 busy workers so the pool is at capacity. + w0, cleanup0 := makeFakeWorker(t, 0) + defer cleanup0() + w1, cleanup1 := makeFakeWorker(t, 1) + defer cleanup1() + + pool.mu.Lock() + w0.activeSessions = 1 + w1.activeSessions = 1 + pool.workers[0] = w0 + pool.workers[1] = w1 + pool.nextWorkerID = 2 + // Fill the semaphore to match the 2 active workers. + pool.workerSem <- struct{}{} + pool.workerSem <- struct{}{} + pool.mu.Unlock() + + // AcquireWorker should now block because the semaphore is full. 
+ acquired := make(chan struct{}) + go func() { + // This will block until a slot opens. + _, _ = pool.AcquireWorker(context.Background()) + close(acquired) + }() + + // Verify it doesn't return immediately. + select { + case <-acquired: + t.Fatal("AcquireWorker should block when pool is at capacity") + case <-time.After(100 * time.Millisecond): + // expected: still blocked + } + + // Retire one worker to free a slot. + pool.RetireWorker(0) + + // Now AcquireWorker should unblock (it will try to spawn, which may fail, + // but the point is it unblocked from the semaphore). + select { + case <-acquired: + // expected: unblocked + case <-time.After(15 * time.Second): + t.Fatal("AcquireWorker did not unblock after RetireWorker") + } +} + +func TestAcquireWorkerRespectsContextCancellation(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 1) + + // Fill the single semaphore slot. + pool.workerSem <- struct{}{} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := pool.AcquireWorker(ctx) + if err == nil { + t.Fatal("expected error from cancelled context") + } + if ctx.Err() == nil { + t.Fatal("expected context to be done") + } +} + +func TestAcquireWorkerUnlimitedWhenMaxZero(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 0) + + if pool.workerSem != nil { + t.Fatal("expected nil workerSem when maxWorkers=0") + } + + // AcquireWorker should not block on semaphore (it will fail trying to + // spawn a worker with a fake binary, but should get past the semaphore). + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := pool.AcquireWorker(ctx) + // Will fail to spawn since there's no real binary, but the point is + // it didn't block on a nil semaphore. + if err == nil { + t.Fatal("expected spawn error with non-existent binary") + } +} + +func TestAcquireWorkerShutdownUnblocksWaiters(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 1) + + // Fill the single semaphore slot. + pool.workerSem <- struct{}{} + + errCh := make(chan error, 1) + go func() { + _, err := pool.AcquireWorker(context.Background()) + errCh <- err + }() + + // Give the goroutine time to block on the semaphore. + time.Sleep(50 * time.Millisecond) + + pool.ShutdownAll() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error after shutdown") + } + case <-time.After(3 * time.Second): + t.Fatal("AcquireWorker did not unblock after ShutdownAll") + } +} + +func TestRetireWorkerIfNoSessions_ReleasesClaimOnFailure(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 1) + + // Manually inject a worker with 1 active session (as if AcquireWorker just returned it) + w, cleanup := makeFakeWorker(t, 1) + defer cleanup() + w.activeSessions = 1 + + pool.mu.Lock() + pool.workers[1] = w + pool.workerSem <- struct{}{} // Claim slot + pool.mu.Unlock() + + // Calling RetireWorkerIfNoSessions should release the claim and kill the worker. + if !pool.RetireWorkerIfNoSessions(1) { + t.Fatal("expected RetireWorkerIfNoSessions to return true") + } + + // Verify semaphore is freed: we should be able to push a token. 
+ select { + case pool.workerSem <- struct{}{}: + // success + case <-time.After(100 * time.Millisecond): + t.Fatal("semaphore slot was leaked") + } + + // Verify worker is gone + if _, ok := pool.Worker(1); ok { + t.Fatal("worker should have been retired") + } +} + +func TestAcquireWorker_AtomicClaimRace(t *testing.T) { + // Tests that two concurrent acquisitions don't pick the same idle worker. + const n = 5 + pool := NewFlightWorkerPool(t.TempDir(), "", 10) + + // Pre-warm with n idle workers + for i := 1; i <= n; i++ { + w, cleanup := makeFakeWorker(t, i) + defer cleanup() + pool.mu.Lock() + pool.workers[i] = w + pool.mu.Unlock() + } + + // Simultaneous acquisitions + results := make(chan *ManagedWorker, n) + for i := 0; i < n; i++ { + go func() { + w, _ := pool.AcquireWorker(context.Background()) + results <- w + }() + } + + workers := make(map[int]bool) + for i := 0; i < n; i++ { + w := <-results + if w == nil { + t.Fatal("failed to acquire worker") + } + if workers[w.ID] { + t.Errorf("worker %d was assigned multiple times!", w.ID) + } + workers[w.ID] = true + } +} + +func TestRetireWorker_ReleasesAllSessions(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 5) + + // Inject a worker with 2 sessions + w, cleanup := makeFakeWorker(t, 1) + defer cleanup() + w.activeSessions = 2 + + pool.mu.Lock() + pool.workers[1] = w + pool.workerSem <- struct{}{} + pool.workerSem <- struct{}{} + pool.mu.Unlock() + + // Retire it + pool.RetireWorker(1) + + // Verify 2 slots released: should be able to push 5 tokens (capacity) + for i := 0; i < 5; i++ { + select { + case pool.workerSem <- struct{}{}: + case <-time.After(100 * time.Millisecond): + t.Fatalf("semaphore slot %d was leaked", i) + } + } +} + +func TestCrashReleasesSemaphoreSlots(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 2) + + // Create a worker that exits immediately (simulates crash). + cmd := exec.Command("true") + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start test process: %v", err) + } + done := make(chan struct{}) + w := &ManagedWorker{ + ID: 0, + cmd: cmd, + done: done, + } + go func() { + w.exitErr = cmd.Wait() + close(done) + }() + <-done // wait for it to exit + + pool.mu.Lock() + w.activeSessions = 1 + pool.workers[0] = w + pool.nextWorkerID = 1 + pool.workerSem <- struct{}{} // account for the worker in the semaphore + pool.mu.Unlock() + + // Start health check loop which will detect the crash and release the slot. + crashCh := make(chan int, 1) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go pool.HealthCheckLoop(ctx, 50*time.Millisecond, func(workerID int) { + select { + case crashCh <- workerID: + default: + } + }) + + // Wait for crash to be detected. + select { + case <-crashCh: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for crash detection") + } + + // Verify the semaphore slot was released: we should be able to push 2 tokens + // (maxWorkers=2) since the crashed worker's slot was freed. 
+ var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + select { + case pool.workerSem <- struct{}{}: + case <-time.After(2 * time.Second): + t.Error("semaphore slot not available after crash") + } + }() + } + wg.Wait() +} diff --git a/main.go b/main.go index 813bd0e..b2a0c57 100644 --- a/main.go +++ b/main.go @@ -40,8 +40,9 @@ type FileConfig struct { Threads int `yaml:"threads"` // DuckDB threads per session MemoryBudget string `yaml:"memory_budget"` // Total memory for all sessions (e.g., "24GB") MemoryRebalance *bool `yaml:"memory_rebalance"` // Enable dynamic per-connection memory reallocation - MaxWorkers int `yaml:"max_workers"` // Max worker processes (control-plane mode) - MinWorkers int `yaml:"min_workers"` // Pre-warm worker count (control-plane mode) + MaxWorkers int `yaml:"max_workers"` // Max worker processes (control-plane mode) + MinWorkers int `yaml:"min_workers"` // Pre-warm worker count (control-plane mode) + WorkerQueueTimeout string `yaml:"worker_queue_timeout"` // e.g., "5m" PassthroughUsers []string `yaml:"passthrough_users"` // Users that bypass transpiler + pg_catalog } @@ -169,6 +170,7 @@ func main() { mode := flag.String("mode", "standalone", "Run mode: standalone, control-plane, or duckdb-service") minWorkers := flag.Int("min-workers", 0, "Pre-warm worker count at startup (control-plane mode) (env: DUCKGRES_MIN_WORKERS)") maxWorkers := flag.Int("max-workers", 0, "Max worker processes, 0=unlimited (control-plane mode) (env: DUCKGRES_MAX_WORKERS)") + workerQueueTimeout := flag.String("worker-queue-timeout", "", "How long to wait for an available worker slot (e.g., '5m') (env: DUCKGRES_WORKER_QUEUE_TIMEOUT)") socketDir := flag.String("socket-dir", "/var/run/duckgres", "Unix socket directory (control-plane mode)") handoverSocket := flag.String("handover-socket", "", "Handover socket for graceful deployment (control-plane mode)") @@ -207,6 +209,7 @@ func main() { fmt.Fprintf(os.Stderr, " DUCKGRES_MEMORY_REBALANCE Enable dynamic per-connection memory reallocation (1 or true)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_MIN_WORKERS Pre-warm worker count (control-plane mode)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_MAX_WORKERS Max worker processes (control-plane mode)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_WORKER_QUEUE_TIMEOUT Worker queue timeout (default: 5m)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_DOMAIN Domain for ACME/Let's Encrypt certificate\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_EMAIL Contact email for Let's Encrypt notifications\n") fmt.Fprintf(os.Stderr, " DUCKGRES_ACME_CACHE_DIR Directory for ACME certificate cache\n") @@ -291,6 +294,7 @@ func main() { MemoryRebalance: *memoryRebalance, MinWorkers: *minWorkers, MaxWorkers: *maxWorkers, + WorkerQueueTimeout: *workerQueueTimeout, ACMEDomain: *acmeDomain, ACMEEmail: *acmeEmail, ACMECacheDir: *acmeCacheDir, @@ -409,10 +413,11 @@ func main() { // Handle control-plane mode if *mode == "control-plane" { cpCfg := controlplane.ControlPlaneConfig{ - Config: cfg, - SocketDir: *socketDir, - ConfigPath: *configFile, - HandoverSocket: *handoverSocket, + Config: cfg, + SocketDir: *socketDir, + ConfigPath: *configFile, + HandoverSocket: *handoverSocket, + WorkerQueueTimeout: resolved.WorkerQueueTimeout, } controlplane.RunControlPlane(cpCfg) return diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index 37ad877..fdb4608 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -42,17 +42,14 @@ const ( const ( 
ReapTriggerPeriodic = "periodic" ReapTriggerForced = "forced" - - MaxWorkersRetryAttempted = "attempted" - MaxWorkersRetrySucceeded = "succeeded" - MaxWorkersRetryFailed = "failed" ) type Config struct { - SessionIdleTTL time.Duration - SessionReapTick time.Duration - HandleIdleTTL time.Duration - SessionTokenTTL time.Duration + SessionIdleTTL time.Duration + SessionReapTick time.Duration + HandleIdleTTL time.Duration + SessionTokenTTL time.Duration + WorkerQueueTimeout time.Duration // applied to CreateSession calls; 0 = use request context as-is } type SessionProvider interface { @@ -63,13 +60,11 @@ type SessionProvider interface { type Hooks struct { OnSessionCountChanged func(int) OnSessionsReaped func(trigger string, count int) - OnMaxWorkersRetry func(outcome string) } type Options struct { - IsMaxWorkersError func(error) bool - Hooks Hooks - RateLimiter *server.RateLimiter + Hooks Hooks + RateLimiter *server.RateLimiter } // FlightIngress serves Arrow Flight SQL on the control plane with Basic auth. @@ -117,7 +112,7 @@ func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[st cfg.SessionTokenTTL = defaultFlightSessionTokenTTL } - store := newFlightAuthSessionStore(provider, cfg.SessionIdleTTL, cfg.SessionReapTick, cfg.HandleIdleTTL, cfg.SessionTokenTTL, opts) + store := newFlightAuthSessionStore(provider, cfg.SessionIdleTTL, cfg.SessionReapTick, cfg.HandleIdleTTL, cfg.SessionTokenTTL, cfg.WorkerQueueTimeout, opts) handler, err := NewControlPlaneFlightSQLHandler(store, users) if err != nil { _ = ln.Close() @@ -1029,13 +1024,13 @@ func (s *flightClientSession) reapStaleHandles(now time.Time, ttl time.Duration) } type flightAuthSessionStore struct { - provider SessionProvider - idleTTL time.Duration - reapInterval time.Duration - handleIdleTTL time.Duration - tokenTTL time.Duration - hooks Hooks - isMaxWorkerFn func(error) bool + provider SessionProvider + idleTTL time.Duration + reapInterval time.Duration + handleIdleTTL time.Duration + tokenTTL time.Duration + workerQueueTimeout time.Duration + hooks Hooks createSessionFn func(context.Context, string) (int32, *server.FlightExecutor, error) destroySessionFn func(int32) @@ -1061,7 +1056,7 @@ func (r *lockedRowSet) Close() error { return err } -func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, handleIdleTTL, tokenTTL time.Duration, opts Options) *flightAuthSessionStore { +func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, handleIdleTTL, tokenTTL, workerQueueTimeout time.Duration, opts Options) *flightAuthSessionStore { createFn := func(context.Context, string) (int32, *server.FlightExecutor, error) { return 0, nil, fmt.Errorf("session provider is not configured") } @@ -1070,25 +1065,21 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, createFn = provider.CreateSession destroyFn = provider.DestroySession } - isMaxWorkerFn := opts.IsMaxWorkersError - if isMaxWorkerFn == nil { - isMaxWorkerFn = func(error) bool { return false } - } s := &flightAuthSessionStore{ - provider: provider, - idleTTL: idleTTL, - reapInterval: reapInterval, - handleIdleTTL: handleIdleTTL, - tokenTTL: tokenTTL, - hooks: opts.Hooks, - isMaxWorkerFn: isMaxWorkerFn, - createSessionFn: createFn, - destroySessionFn: destroyFn, - sessions: make(map[string]*flightClientSession), - byKey: make(map[string]string), - stopCh: make(chan struct{}), - doneCh: make(chan struct{}), + provider: provider, + idleTTL: idleTTL, + reapInterval: reapInterval, + handleIdleTTL: 
handleIdleTTL, + tokenTTL: tokenTTL, + workerQueueTimeout: workerQueueTimeout, + hooks: opts.Hooks, + createSessionFn: createFn, + destroySessionFn: destroyFn, + sessions: make(map[string]*flightClientSession), + byKey: make(map[string]string), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), } go s.reapLoop() return s @@ -1099,7 +1090,33 @@ func (s *flightAuthSessionStore) Create(ctx context.Context, username string) (* if err != nil { return nil, fmt.Errorf("generate bootstrap nonce: %w", err) } - return s.GetOrCreate(ctx, "bootstrap|"+username+"|"+bootstrapNonce, username) + key := "bootstrap|" + username + "|" + bootstrapNonce + + // 1. Try a fast acquisition first. If slots are available (busy or idle), this succeeds immediately. + fastCtx, fastCancel := context.WithTimeout(ctx, 50*time.Millisecond) + sess, err := s.GetOrCreate(fastCtx, key, username) + fastCancel() + if err == nil { + return sess, nil + } + + // 2. Acquisition failed or is queuing. Trigger a forced reap of IDLE sessions + // to free up slots for this and other queued requests. + if ctx.Err() == nil { + reaped := s.reapIdle(time.Now(), ReapTriggerForced) + if reaped > 0 { + slog.Info("Flight auth session store forced idle reap due to worker exhaustion.", "reaped_sessions", reaped) + } + } + + // 3. Apply worker queue timeout for the final (potentially blocking) attempt. + if s.workerQueueTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.workerQueueTimeout) + defer cancel() + } + + return s.GetOrCreate(ctx, key, username) } func (s *flightAuthSessionStore) notifySessionCountChanged(count int) { @@ -1117,12 +1134,6 @@ func (s *flightAuthSessionStore) notifySessionsReaped(trigger string, count int) } } -func (s *flightAuthSessionStore) notifyMaxWorkerRetry(outcome string) { - if s.hooks.OnMaxWorkersRetry != nil { - s.hooks.OnMaxWorkersRetry(outcome) - } -} - func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username string) (*flightClientSession, error) { existing, ok := s.getExistingByKey(key) if ok { @@ -1131,26 +1142,6 @@ func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username } pid, executor, err := s.createSessionFn(ctx, username) - if err != nil && s.isMaxWorkerFn(err) { - s.notifyMaxWorkerRetry(MaxWorkersRetryAttempted) - reaped := s.reapIdle(time.Now(), ReapTriggerForced) - if reaped > 0 { - slog.Info("Flight auth session store forced idle reap before retry.", "reaped_sessions", reaped) - } - pid, executor, err = s.createSessionFn(ctx, username) - if err != nil { - if s.isMaxWorkerFn(err) { - if existing, ok := s.waitForExisting(ctx, key, 5*time.Second); ok { - existing.touch() - s.notifyMaxWorkerRetry(MaxWorkersRetrySucceeded) - return existing, nil - } - } - s.notifyMaxWorkerRetry(MaxWorkersRetryFailed) - } else { - s.notifyMaxWorkerRetry(MaxWorkersRetrySucceeded) - } - } if err != nil { return nil, err } @@ -1282,24 +1273,6 @@ func (s *flightAuthSessionStore) ensureMapsLocked() { } } -func (s *flightAuthSessionStore) waitForExisting(ctx context.Context, key string, timeout time.Duration) (*flightClientSession, bool) { - deadline := time.Now().Add(timeout) - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - - for time.Now().Before(deadline) { - if existing, ok := s.getExistingByKey(key); ok { - return existing, true - } - select { - case <-ctx.Done(): - return nil, false - case <-ticker.C: - } - } - return s.getExistingByKey(key) -} - func (s *flightAuthSessionStore) Close() { 
s.stopOnce.Do(func() { close(s.stopCh) diff --git a/server/flightsqlingress/ingress_test.go b/server/flightsqlingress/ingress_test.go index 6a2a57a..db595dd 100644 --- a/server/flightsqlingress/ingress_test.go +++ b/server/flightsqlingress/ingress_test.go @@ -21,8 +21,6 @@ import ( "google.golang.org/grpc/status" ) -var errTestMaxWorkersReached = errors.New("max workers reached") - type testExecResult struct { affected int64 err error @@ -542,87 +540,6 @@ func TestSessionFromContextFailedAndSuccessfulAuthUpdateRateLimiter(t *testing.T } } -func TestFlightAuthSessionStoreRetriesAfterForcedReapOnMaxWorkers(t *testing.T) { - stale := newFlightClientSession(9001, "postgres", nil) - stale.lastUsed.Store(time.Now().Add(-1 * time.Hour).UnixNano()) - - createCalls := 0 - destroyed := make([]int32, 0, 2) - store := &flightAuthSessionStore{ - idleTTL: time.Minute, - reapInterval: time.Hour, - handleIdleTTL: time.Minute, - isMaxWorkerFn: func(err error) bool { return errors.Is(err, errTestMaxWorkersReached) }, - sessions: map[string]*flightClientSession{ - "stale-session": stale, - }, - stopCh: make(chan struct{}), - doneCh: make(chan struct{}), - } - store.createSessionFn = func(context.Context, string) (int32, *server.FlightExecutor, error) { - createCalls++ - if createCalls == 1 { - return 0, nil, fmt.Errorf("acquire worker: %w", errTestMaxWorkersReached) - } - return 9002, nil, nil - } - store.destroySessionFn = func(pid int32) { - destroyed = append(destroyed, pid) - } - - session, err := store.GetOrCreate(context.Background(), "new", "postgres") - if err != nil { - t.Fatalf("expected GetOrCreate retry to succeed, got error %v", err) - } - if session == nil { - t.Fatalf("expected non-nil session after retry") - } - if createCalls != 2 { - t.Fatalf("expected 2 create attempts, got %d", createCalls) - } - if len(destroyed) != 1 || destroyed[0] != 9001 { - t.Fatalf("expected stale session to be reaped before retry, got destroyed=%v", destroyed) - } -} - -func TestFlightAuthSessionStoreRetryHookEvents(t *testing.T) { - outcomes := make([]string, 0, 2) - store := &flightAuthSessionStore{ - idleTTL: time.Minute, - reapInterval: time.Hour, - handleIdleTTL: time.Minute, - isMaxWorkerFn: func(err error) bool { return errors.Is(err, errTestMaxWorkersReached) }, - sessions: make(map[string]*flightClientSession), - stopCh: make(chan struct{}), - doneCh: make(chan struct{}), - hooks: Hooks{ - OnMaxWorkersRetry: func(outcome string) { - outcomes = append(outcomes, outcome) - }, - }, - } - - createCalls := 0 - store.createSessionFn = func(context.Context, string) (int32, *server.FlightExecutor, error) { - createCalls++ - if createCalls == 1 { - return 0, nil, fmt.Errorf("acquire worker: %w", errTestMaxWorkersReached) - } - return 1234, nil, nil - } - store.destroySessionFn = func(int32) {} - - if _, err := store.GetOrCreate(context.Background(), "k", "postgres"); err != nil { - t.Fatalf("expected retry path to succeed, got %v", err) - } - if len(outcomes) != 2 { - t.Fatalf("expected 2 retry outcomes, got %d (%v)", len(outcomes), outcomes) - } - if outcomes[0] != MaxWorkersRetryAttempted || outcomes[1] != MaxWorkersRetrySucceeded { - t.Fatalf("unexpected retry outcomes: %v", outcomes) - } -} - func TestFlightAuthSessionStoreReapHookReceivesTrigger(t *testing.T) { stale := newFlightClientSession(1234, "postgres", nil) stale.lastUsed.Store(time.Now().Add(-1 * time.Hour).UnixNano()) @@ -661,68 +578,6 @@ func TestFlightAuthSessionStoreReapHookReceivesTrigger(t *testing.T) { } } -func 
TestFlightAuthSessionStoreConcurrentCreateReusesExistingAfterMaxWorkers(t *testing.T) { - store := &flightAuthSessionStore{ - idleTTL: time.Minute, - reapInterval: time.Hour, - handleIdleTTL: time.Minute, - isMaxWorkerFn: func(err error) bool { return errors.Is(err, errTestMaxWorkersReached) }, - sessions: make(map[string]*flightClientSession), - stopCh: make(chan struct{}), - doneCh: make(chan struct{}), - } - - var createCalls atomic.Int32 - firstCreateStarted := make(chan struct{}) - releaseFirstCreate := make(chan struct{}) - store.createSessionFn = func(context.Context, string) (int32, *server.FlightExecutor, error) { - callNum := createCalls.Add(1) - if callNum == 1 { - close(firstCreateStarted) - <-releaseFirstCreate - return 1001, nil, nil - } - return 0, nil, fmt.Errorf("acquire worker: %w", errTestMaxWorkersReached) - } - store.destroySessionFn = func(int32) {} - - var wg sync.WaitGroup - wg.Add(2) - - var s1, s2 *flightClientSession - var err1, err2 error - - go func() { - defer wg.Done() - s1, err1 = store.GetOrCreate(context.Background(), "shared-key", "postgres") - }() - - <-firstCreateStarted - go func() { - defer wg.Done() - s2, err2 = store.GetOrCreate(context.Background(), "shared-key", "postgres") - }() - - close(releaseFirstCreate) - wg.Wait() - - if err1 != nil { - t.Fatalf("first GetOrCreate failed: %v", err1) - } - if err2 != nil { - t.Fatalf("second GetOrCreate failed: %v", err2) - } - if s1 == nil || s2 == nil { - t.Fatalf("expected non-nil sessions, got s1=%v s2=%v", s1, s2) - } - if s1 != s2 { - t.Fatalf("expected both callers to share the same session pointer") - } - if createCalls.Load() < 1 { - t.Fatalf("expected at least 1 create attempt, got %d", createCalls.Load()) - } -} - func TestFlightAuthSessionStoreReapKeepsSessionWithFreshHandle(t *testing.T) { cs := newFlightClientSession(1234, "postgres", nil) cs.lastUsed.Store(time.Now().Add(-1 * time.Hour).UnixNano()) diff --git a/tests/controlplane/flight_ingress_test.go b/tests/controlplane/flight_ingress_test.go index c87a05a..57354fe 100644 --- a/tests/controlplane/flight_ingress_test.go +++ b/tests/controlplane/flight_ingress_test.go @@ -117,7 +117,7 @@ func requireGetTablesIncludeSchema(t *testing.T, client *flightsql.Client, ctx c func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) { h := startControlPlane(t, cpOpts{ flightPort: freePort(t), - maxWorkers: 1, + maxWorkers: 3, }) const goroutines = 3
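
Note: the worker-queue behaviour introduced in this patch reduces to a buffered-channel semaphore that callers block on in FIFO order, with context cancellation and a shutdown channel as the two escape hatches. The following is a minimal standalone sketch of that pattern only; the slotPool type and its method names are hypothetical illustrations, not the duckgres API, and error handling around spawning/retiring workers is omitted.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// slotPool is a stripped-down stand-in for the semaphore logic used by
// FlightWorkerPool.AcquireWorker in the diff above.
type slotPool struct {
	sem      chan struct{} // buffered to maxWorkers; nil means unlimited
	shutdown chan struct{} // closed to unblock queued waiters on shutdown
}

func newSlotPool(maxWorkers int) *slotPool {
	p := &slotPool{shutdown: make(chan struct{})}
	if maxWorkers > 0 {
		p.sem = make(chan struct{}, maxWorkers)
	}
	return p
}

// acquire blocks until a slot is free, the context expires, or the pool shuts down.
func (p *slotPool) acquire(ctx context.Context) error {
	if p.sem == nil {
		return nil // unlimited: never queue
	}
	select {
	case p.sem <- struct{}{}:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("timed out waiting for worker slot: %w", ctx.Err())
	case <-p.shutdown:
		return errors.New("pool is shutting down")
	}
}

// release frees one slot without blocking (e.g. on session teardown or worker crash).
func (p *slotPool) release() {
	if p.sem == nil {
		return
	}
	select {
	case <-p.sem:
	default:
	}
}

func main() {
	pool := newSlotPool(1)

	// The timeout corresponds to worker_queue_timeout / --worker-queue-timeout /
	// DUCKGRES_WORKER_QUEUE_TIMEOUT in the config resolution above (default 5m).
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	if err := pool.acquire(ctx); err != nil {
		fmt.Println("no slot:", err)
		return
	}
	defer pool.release()
	fmt.Println("got a worker slot")
}

As in the patch, the slot must be released on every failure path after acquisition (spawn failure, session-creation failure, worker crash, retirement), otherwise queued waiters would block until the timeout even though capacity is available.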