4 changes: 4 additions & 0 deletions controlplane/control.go
@@ -489,6 +489,10 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
}
defer cp.sessions.DestroySession(pid)

// Register the TCP connection so OnWorkerCrash can close it to unblock
// the message loop if the backing worker dies.
cp.sessions.SetConnCloser(pid, tlsConn)

secretKey := server.GenerateSecretKey()

// Create clientConn with FlightExecutor
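For context, here is a minimal sketch of the message-loop shape this registration assumes (readMessage, execute, and sendError are hypothetical names; the real loop lives in handleConnection and its callees):

	for {
		msg, err := readMessage(tlsConn) // blocks until data arrives or the conn is closed
		if err != nil {
			return // conn closed (by the client or by OnWorkerCrash): exit cleanly
		}
		if err := execute(msg); err != nil {
			sendError(tlsConn, err) // before this fix, ErrWorkerDead repeated here on every query
		}
	}

Closing the connection from OnWorkerCrash forces the blocked readMessage to return an error, so the goroutine unwinds through the deferred DestroySession above instead of spinning.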
34 changes: 33 additions & 1 deletion controlplane/session_mgr.go
@@ -3,6 +3,7 @@ package controlplane
import (
"context"
"fmt"
"io"
"log/slog"
"sync"
"sync/atomic"
@@ -17,6 +18,7 @@ type ManagedSession struct {
WorkerID int
SessionToken string
Executor *server.FlightExecutor
connCloser io.Closer // TCP connection, closed on worker crash to unblock the message loop
}

// SessionManager tracks all active sessions and their worker assignments.
@@ -138,12 +140,22 @@ func (sm *SessionManager) DestroySession(pid int32) {
}
}

// OnWorkerCrash handles a worker crash by sending errors to all affected sessions.
// OnWorkerCrash handles a worker crash by marking all affected executors as
// dead and notifying sessions. Executors are marked dead BEFORE the shared
// gRPC client is closed to prevent nil-pointer panics from concurrent RPCs.
// errorFn is called for each affected session to send an error to the client.
func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
sm.mu.Lock()
pids := make([]int32, len(sm.byWorker[workerID]))
copy(pids, sm.byWorker[workerID])

// Mark all executors as dead first (under lock) so any concurrent RPC
// sees the dead flag before the gRPC client is closed.
for _, pid := range pids {
if s, ok := sm.sessions[pid]; ok && s.Executor != nil {
s.Executor.MarkDead()
}
}
sm.mu.Unlock()

slog.Warn("Worker crashed, notifying sessions.", "worker", workerID, "sessions", len(pids))
@@ -157,6 +169,15 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
if session.Executor != nil {
_ = session.Executor.Close()
}
// Close the TCP connection to unblock the message loop's read.
// This causes the session goroutine to exit instead of looping
// with ErrWorkerDead on every query. The deferred close in
// handleConnection will also call Close() on the same conn;
// that's harmless (net.Conn.Close on a closed socket returns
// an error which is discarded).
if session.connCloser != nil {
_ = session.connCloser.Close()
}
}
sm.mu.Unlock()
}
@@ -171,6 +192,17 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
}
}

// SetConnCloser registers the client's TCP connection so it can be closed
// when the backing worker crashes. This unblocks the message loop's read,
// causing it to exit cleanly instead of looping on ErrWorkerDead.
func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) {
sm.mu.Lock()
defer sm.mu.Unlock()
if s, ok := sm.sessions[pid]; ok {
s.connCloser = closer
}
}

// SessionCount returns the number of active sessions.
func (sm *SessionManager) SessionCount() int {
sm.mu.RLock()
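A minimal sketch of the dead-flag ordering OnWorkerCrash now relies on, assuming FlightExecutor guards its shared gRPC client with an atomic flag (field names and the Execute signature are illustrative, not the actual server package code; assumes the obvious imports: context, errors, sync/atomic):

	var ErrWorkerDead = errors.New("worker process died")

	type FlightExecutor struct {
		dead atomic.Bool
		// client: the shared gRPC client, torn down by Close()
	}

	func (e *FlightExecutor) MarkDead()    { e.dead.Store(true) }
	func (e *FlightExecutor) IsDead() bool { return e.dead.Load() }

	func (e *FlightExecutor) Execute(ctx context.Context, query string) error {
		if e.dead.Load() {
			return ErrWorkerDead // fail fast; never touch the client after a crash
		}
		// ... issue the Flight RPC via the client ...
		return nil
	}

Because MarkDead runs under sm.mu before any Close() call, a query that arrives after the crash is detected fails fast with ErrWorkerDead; recoverWorkerPanic (tested below) covers the narrower case where an RPC is already in flight when the client is torn down.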
253 changes: 253 additions & 0 deletions controlplane/session_mgr_test.go
@@ -0,0 +1,253 @@
package controlplane

import (
"runtime"
"strings"
"sync/atomic"
"testing"

"github.com/posthog/duckgres/server"
)

// mockCloser tracks whether Close was called.
type mockCloser struct {
closed atomic.Bool
}

func (m *mockCloser) Close() error {
m.closed.Store(true)
return nil
}

func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

executor := &server.FlightExecutor{}
pid := int32(1001)

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 5,
Executor: executor,
}
sm.byWorker[5] = []int32{pid}
sm.mu.Unlock()

var notifiedPIDs []int32
sm.OnWorkerCrash(5, func(pid int32) {
notifiedPIDs = append(notifiedPIDs, pid)
})

// Executor should be marked dead
if !executor.IsDead() {
t.Fatal("expected executor to be marked dead after OnWorkerCrash")
}

// errorFn should have been called
if len(notifiedPIDs) != 1 || notifiedPIDs[0] != pid {
t.Fatalf("expected errorFn called with pid %d, got %v", pid, notifiedPIDs)
}

// Session should be removed
if sm.SessionCount() != 0 {
t.Fatalf("expected 0 sessions after crash, got %d", sm.SessionCount())
}
}

func TestOnWorkerCrash_ClosesConnections(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

conn := &mockCloser{}
executor := &server.FlightExecutor{}
pid := int32(1002)

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 7,
Executor: executor,
connCloser: conn,
}
sm.byWorker[7] = []int32{pid}
sm.mu.Unlock()

sm.OnWorkerCrash(7, func(pid int32) {})

if !conn.closed.Load() {
t.Fatal("expected TCP connection to be closed on worker crash")
}
}

func TestOnWorkerCrash_MultipleSessions(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

exec1 := &server.FlightExecutor{}
exec2 := &server.FlightExecutor{}
conn1 := &mockCloser{}
conn2 := &mockCloser{}

sm.mu.Lock()
sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, connCloser: conn1}
sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, connCloser: conn2}
sm.byWorker[3] = []int32{1001, 1002}
sm.mu.Unlock()

sm.OnWorkerCrash(3, func(pid int32) {})

if !exec1.IsDead() || !exec2.IsDead() {
t.Fatal("expected both executors to be marked dead")
}
if !conn1.closed.Load() || !conn2.closed.Load() {
t.Fatal("expected both connections to be closed")
}
if sm.SessionCount() != 0 {
t.Fatalf("expected 0 sessions, got %d", sm.SessionCount())
}
}

func TestSetConnCloser(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

pid := int32(1003)
sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{PID: pid, WorkerID: 1}
sm.byWorker[1] = []int32{pid}
sm.mu.Unlock()

conn := &mockCloser{}
sm.SetConnCloser(pid, conn)

// Verify it was set by triggering a crash
sm.OnWorkerCrash(1, func(pid int32) {})

if !conn.closed.Load() {
t.Fatal("expected connection registered via SetConnCloser to be closed on crash")
}
}

func TestSetConnCloser_UnknownPID(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

// Should not panic when PID doesn't exist
conn := &mockCloser{}
sm.SetConnCloser(9999, conn)

if conn.closed.Load() {
t.Fatal("connection should not be closed for unknown PID")
}
}

func TestRecoverWorkerPanic_NilPointer(t *testing.T) {
var err error
func() {
defer recoverWorkerPanic(&err)
var i *int
_ = *i //nolint:govet
}()

if err == nil {
t.Fatal("expected error from recovered nil pointer panic")
}
if !strings.Contains(err.Error(), "worker likely crashed") {
t.Fatalf("expected crash message, got: %v", err)
}
}

func TestRecoverWorkerPanic_NonNilPointerRePanics(t *testing.T) {
defer func() {
r := recover()
if r == nil {
t.Fatal("expected re-panic for non-nil-pointer panic")
}
if s, ok := r.(string); !ok || s != "unrelated panic" {
t.Fatalf("expected original panic value, got: %v", r)
}
}()

var err error
func() {
defer recoverWorkerPanic(&err)
panic("unrelated panic")
}()

t.Fatal("should not reach here")
}

func TestRecoverWorkerPanic_RuntimeErrorRePanics(t *testing.T) {
defer func() {
r := recover()
if r == nil {
t.Fatal("expected re-panic for non-nil-pointer runtime error")
}
if re, ok := r.(runtime.Error); !ok {
t.Fatalf("expected runtime.Error, got %T: %v", r, r)
} else if strings.Contains(re.Error(), "nil pointer") {
t.Fatal("this test should use a non-nil-pointer runtime error")
}
}()

var err error
func() {
defer recoverWorkerPanic(&err)
s := []int{}
_ = s[1] //nolint:govet
}()

t.Fatal("should not reach here")
}

func TestDestroySessionAfterOnWorkerCrash(t *testing.T) {
// Verify that DestroySession is a safe no-op when OnWorkerCrash already
// cleaned up the session. This is the exact production sequence:
// OnWorkerCrash runs from the health check, then the deferred
// DestroySession runs when handleConnection returns.
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
}
sm := NewSessionManager(pool, nil)

conn := &mockCloser{}
executor := &server.FlightExecutor{}
pid := int32(1010)

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 9,
Executor: executor,
connCloser: conn,
}
sm.byWorker[9] = []int32{pid}
sm.mu.Unlock()

// Simulate crash cleanup
sm.OnWorkerCrash(9, func(pid int32) {})

if sm.SessionCount() != 0 {
t.Fatal("expected 0 sessions after OnWorkerCrash")
}

// Now DestroySession runs (from deferred call in handleConnection).
// Should be a no-op — no panic, no double-close of worker resources.
sm.DestroySession(pid)

if sm.SessionCount() != 0 {
t.Fatal("expected 0 sessions after DestroySession")
}
}
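Since the crash path crosses goroutines in production, these tests are worth running under the race detector:

	go test -race -run 'TestOnWorkerCrash|TestSetConnCloser|TestRecoverWorkerPanic|TestDestroySession' ./controlplane/

recoverWorkerPanic itself is not part of this diff. A plausible reconstruction from the assertions above (assuming fmt, runtime, and strings are imported; the real implementation may differ):

	func recoverWorkerPanic(err *error) {
		r := recover()
		if r == nil {
			return
		}
		// Only nil-pointer runtime errors are treated as a crashed worker;
		// anything else re-panics with its original value.
		re, ok := r.(runtime.Error)
		if !ok || !strings.Contains(re.Error(), "nil pointer") {
			panic(r)
		}
		*err = fmt.Errorf("worker likely crashed: %v", re)
	}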