diff --git a/controlplane/control.go b/controlplane/control.go
index 6e41464..06058b6 100644
--- a/controlplane/control.go
+++ b/controlplane/control.go
@@ -489,6 +489,10 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
 	}
 	defer cp.sessions.DestroySession(pid)
 
+	// Register the TCP connection so OnWorkerCrash can close it to unblock
+	// the message loop if the backing worker dies.
+	cp.sessions.SetConnCloser(pid, tlsConn)
+
 	secretKey := server.GenerateSecretKey()
 
 	// Create clientConn with FlightExecutor
diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go
index c3f8e30..39af506 100644
--- a/controlplane/session_mgr.go
+++ b/controlplane/session_mgr.go
@@ -3,6 +3,7 @@ package controlplane
 import (
 	"context"
 	"fmt"
+	"io"
 	"log/slog"
 	"sync"
 	"sync/atomic"
@@ -17,6 +18,7 @@ type ManagedSession struct {
 	WorkerID     int
 	SessionToken string
 	Executor     *server.FlightExecutor
+	connCloser   io.Closer // TCP connection, closed on worker crash to unblock the message loop
 }
 
 // SessionManager tracks all active sessions and their worker assignments.
@@ -138,12 +140,22 @@ func (sm *SessionManager) DestroySession(pid int32) {
 	}
 }
 
-// OnWorkerCrash handles a worker crash by sending errors to all affected sessions.
+// OnWorkerCrash handles a worker crash by marking all affected executors as
+// dead and notifying sessions. Executors are marked dead BEFORE the shared
+// gRPC client is closed to prevent nil-pointer panics from concurrent RPCs.
 // errorFn is called for each affected session to send an error to the client.
 func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
 	sm.mu.Lock()
 	pids := make([]int32, len(sm.byWorker[workerID]))
 	copy(pids, sm.byWorker[workerID])
+
+	// Mark all executors as dead first (under lock) so any concurrent RPC
+	// sees the dead flag before the gRPC client is closed.
+	for _, pid := range pids {
+		if s, ok := sm.sessions[pid]; ok && s.Executor != nil {
+			s.Executor.MarkDead()
+		}
+	}
 	sm.mu.Unlock()
 
 	slog.Warn("Worker crashed, notifying sessions.", "worker", workerID, "sessions", len(pids))
@@ -157,6 +169,15 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
 		if session.Executor != nil {
 			_ = session.Executor.Close()
 		}
+		// Close the TCP connection to unblock the message loop's read.
+		// This causes the session goroutine to exit instead of looping
+		// with ErrWorkerDead on every query. The deferred close in
+		// handleConnection will also call Close() on the same conn;
+		// that's harmless (net.Conn.Close on a closed socket returns
+		// an error which is discarded).
+		if session.connCloser != nil {
+			_ = session.connCloser.Close()
+		}
 	}
 	sm.mu.Unlock()
 }
@@ -171,6 +192,17 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
 	}
 }
 
+// SetConnCloser registers the client's TCP connection so it can be closed
+// when the backing worker crashes. This unblocks the message loop's read,
+// causing it to exit cleanly instead of looping on ErrWorkerDead.
+func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) {
+	sm.mu.Lock()
+	defer sm.mu.Unlock()
+	if s, ok := sm.sessions[pid]; ok {
+		s.connCloser = closer
+	}
+}
+
 // SessionCount returns the number of active sessions.
 func (sm *SessionManager) SessionCount() int {
 	sm.mu.RLock()
diff --git a/controlplane/session_mgr_test.go b/controlplane/session_mgr_test.go
new file mode 100644
index 0000000..b0c2404
--- /dev/null
+++ b/controlplane/session_mgr_test.go
@@ -0,0 +1,253 @@
+package controlplane
+
+import (
+	"runtime"
+	"strings"
+	"sync/atomic"
+	"testing"
+
+	"github.com/posthog/duckgres/server"
+)
+
+// mockCloser tracks whether Close was called.
+type mockCloser struct {
+	closed atomic.Bool
+}
+
+func (m *mockCloser) Close() error {
+	m.closed.Store(true)
+	return nil
+}
+
+func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) {
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	executor := &server.FlightExecutor{}
+	pid := int32(1001)
+
+	sm.mu.Lock()
+	sm.sessions[pid] = &ManagedSession{
+		PID:      pid,
+		WorkerID: 5,
+		Executor: executor,
+	}
+	sm.byWorker[5] = []int32{pid}
+	sm.mu.Unlock()
+
+	var notifiedPIDs []int32
+	sm.OnWorkerCrash(5, func(pid int32) {
+		notifiedPIDs = append(notifiedPIDs, pid)
+	})
+
+	// Executor should be marked dead
+	if !executor.IsDead() {
+		t.Fatal("expected executor to be marked dead after OnWorkerCrash")
+	}
+
+	// errorFn should have been called
+	if len(notifiedPIDs) != 1 || notifiedPIDs[0] != pid {
+		t.Fatalf("expected errorFn called with pid %d, got %v", pid, notifiedPIDs)
+	}
+
+	// Session should be removed
+	if sm.SessionCount() != 0 {
+		t.Fatalf("expected 0 sessions after crash, got %d", sm.SessionCount())
+	}
+}
+
+func TestOnWorkerCrash_ClosesConnections(t *testing.T) {
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	conn := &mockCloser{}
+	executor := &server.FlightExecutor{}
+	pid := int32(1002)
+
+	sm.mu.Lock()
+	sm.sessions[pid] = &ManagedSession{
+		PID:        pid,
+		WorkerID:   7,
+		Executor:   executor,
+		connCloser: conn,
+	}
+	sm.byWorker[7] = []int32{pid}
+	sm.mu.Unlock()
+
+	sm.OnWorkerCrash(7, func(pid int32) {})
+
+	if !conn.closed.Load() {
+		t.Fatal("expected TCP connection to be closed on worker crash")
+	}
+}
+
+func TestOnWorkerCrash_MultipleSessions(t *testing.T) {
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	exec1 := &server.FlightExecutor{}
+	exec2 := &server.FlightExecutor{}
+	conn1 := &mockCloser{}
+	conn2 := &mockCloser{}
+
+	sm.mu.Lock()
+	sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, connCloser: conn1}
+	sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, connCloser: conn2}
+	sm.byWorker[3] = []int32{1001, 1002}
+	sm.mu.Unlock()
+
+	sm.OnWorkerCrash(3, func(pid int32) {})
+
+	if !exec1.IsDead() || !exec2.IsDead() {
+		t.Fatal("expected both executors to be marked dead")
+	}
+	if !conn1.closed.Load() || !conn2.closed.Load() {
+		t.Fatal("expected both connections to be closed")
+	}
+	if sm.SessionCount() != 0 {
+		t.Fatalf("expected 0 sessions, got %d", sm.SessionCount())
+	}
+}
+
+func TestSetConnCloser(t *testing.T) {
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	pid := int32(1003)
+	sm.mu.Lock()
+	sm.sessions[pid] = &ManagedSession{PID: pid, WorkerID: 1}
+	sm.byWorker[1] = []int32{pid}
+	sm.mu.Unlock()
+
+	conn := &mockCloser{}
+	sm.SetConnCloser(pid, conn)
+
+	// Verify it was set by triggering a crash
+	sm.OnWorkerCrash(1, func(pid int32) {})
+
+	if !conn.closed.Load() {
+		t.Fatal("expected connection registered via SetConnCloser to be closed on crash")
+	}
+}
+
+func TestSetConnCloser_UnknownPID(t *testing.T) {
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	// Should not panic when PID doesn't exist
+	conn := &mockCloser{}
+	sm.SetConnCloser(9999, conn)
+
+	if conn.closed.Load() {
+		t.Fatal("connection should not be closed for unknown PID")
+	}
+}
+
+func TestRecoverWorkerPanic_NilPointer(t *testing.T) {
+	var err error
+	func() {
+		defer recoverWorkerPanic(&err)
+		var i *int
+		_ = *i //nolint:govet
+	}()
+
+	if err == nil {
+		t.Fatal("expected error from recovered nil pointer panic")
+	}
+	if !strings.Contains(err.Error(), "worker likely crashed") {
+		t.Fatalf("expected crash message, got: %v", err)
+	}
+}
+
+func TestRecoverWorkerPanic_NonNilPointerRePanics(t *testing.T) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			t.Fatal("expected re-panic for non-nil-pointer panic")
+		}
+		if s, ok := r.(string); !ok || s != "unrelated panic" {
+			t.Fatalf("expected original panic value, got: %v", r)
+		}
+	}()
+
+	var err error
+	func() {
+		defer recoverWorkerPanic(&err)
+		panic("unrelated panic")
+	}()
+
+	t.Fatal("should not reach here")
+}
+
+func TestRecoverWorkerPanic_RuntimeErrorRePanics(t *testing.T) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			t.Fatal("expected re-panic for non-nil-pointer runtime error")
+		}
+		if re, ok := r.(runtime.Error); !ok {
+			t.Fatalf("expected runtime.Error, got %T: %v", r, r)
+		} else if strings.Contains(re.Error(), "nil pointer") {
+			t.Fatal("this test should use a non-nil-pointer runtime error")
+		}
+	}()
+
+	var err error
+	func() {
+		defer recoverWorkerPanic(&err)
+		s := []int{}
+		_ = s[1] //nolint:govet
+	}()
+
+	t.Fatal("should not reach here")
+}
+
+func TestDestroySessionAfterOnWorkerCrash(t *testing.T) {
+	// Verify that DestroySession is a safe no-op when OnWorkerCrash already
+	// cleaned up the session. This is the exact production sequence:
+	// OnWorkerCrash runs from the health check, then the deferred
+	// DestroySession runs when handleConnection returns.
+	pool := &FlightWorkerPool{
+		workers: make(map[int]*ManagedWorker),
+	}
+	sm := NewSessionManager(pool, nil)
+
+	conn := &mockCloser{}
+	executor := &server.FlightExecutor{}
+	pid := int32(1010)
+
+	sm.mu.Lock()
+	sm.sessions[pid] = &ManagedSession{
+		PID:        pid,
+		WorkerID:   9,
+		Executor:   executor,
+		connCloser: conn,
+	}
+	sm.byWorker[9] = []int32{pid}
+	sm.mu.Unlock()
+
+	// Simulate crash cleanup
+	sm.OnWorkerCrash(9, func(pid int32) {})
+
+	if sm.SessionCount() != 0 {
+		t.Fatal("expected 0 sessions after OnWorkerCrash")
+	}
+
+	// Now DestroySession runs (from deferred call in handleConnection).
+	// Should be a no-op — no panic, no double-close of worker resources.
+	sm.DestroySession(pid)
+
+	if sm.SessionCount() != 0 {
+		t.Fatal("expected 0 sessions after DestroySession")
+	}
+}
diff --git a/controlplane/worker_mgr.go b/controlplane/worker_mgr.go
index 378972c..0033219 100644
--- a/controlplane/worker_mgr.go
+++ b/controlplane/worker_mgr.go
@@ -10,6 +10,8 @@ import (
 	"log/slog"
 	"os"
 	"os/exec"
+	"runtime"
+	"strings"
 	"sync"
 	"time"
 
@@ -482,10 +484,18 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du
 			}
 			_ = os.Remove(w.socketPath)
 		default:
-			// Worker is alive, do a health check
-			hctx, cancel := context.WithTimeout(ctx, 3*time.Second)
-			err := doHealthCheck(hctx, w.client)
-			cancel()
+			// Worker is alive, do a health check.
+			// Recover nil-pointer panics: w.client.Close() (from a
+			// concurrent crash/retire) nils out FlightServiceClient,
+			// racing with the DoAction call inside doHealthCheck.
+			var healthErr error
+			func() {
+				defer recoverWorkerPanic(&healthErr)
+				hctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+				healthErr = doHealthCheck(hctx, w.client)
+				cancel()
+			}()
+			err := healthErr
 
 			if err != nil {
 				mu.Lock()
@@ -540,8 +550,23 @@ func (p *FlightWorkerPool) HealthCheckLoop(ctx context.Context, interval time.Du
 	}
 }
 
+// recoverWorkerPanic converts a nil-pointer panic from a closed Flight SQL
+// client into an error. Same race as FlightExecutor: arrow-go Close() nils out
+// FlightServiceClient, and concurrent DoAction calls on the shared client panic.
+func recoverWorkerPanic(err *error) {
+	if r := recover(); r != nil {
+		if re, ok := r.(runtime.Error); ok && strings.Contains(re.Error(), "nil pointer") {
+			*err = fmt.Errorf("worker client panic (worker likely crashed): %v", r)
+			return
+		}
+		panic(r)
+	}
+}
+
 // CreateSession creates a new session on the given worker.
-func (w *ManagedWorker) CreateSession(ctx context.Context, username string) (string, error) {
+func (w *ManagedWorker) CreateSession(ctx context.Context, username string) (token string, err error) {
+	defer recoverWorkerPanic(&err)
+
 	body, _ := json.Marshal(map[string]string{"username": username})
 
 	stream, err := w.client.Client.DoAction(ctx, &flight.Action{
@@ -568,7 +593,9 @@ func (w *ManagedWorker) CreateSession(ctx context.Context, username string) (str
 }
 
 // DestroySession destroys a session on the worker.
-func (w *ManagedWorker) DestroySession(ctx context.Context, sessionToken string) error {
+func (w *ManagedWorker) DestroySession(ctx context.Context, sessionToken string) (err error) {
+	defer recoverWorkerPanic(&err)
+
 	body, _ := json.Marshal(map[string]string{"session_token": sessionToken})
 
 	stream, err := w.client.Client.DoAction(ctx, &flight.Action{
diff --git a/server/flight_executor.go b/server/flight_executor.go
index 362bf77..fff8690 100644
--- a/server/flight_executor.go
+++ b/server/flight_executor.go
@@ -3,10 +3,13 @@ package server
 import (
 	"context"
 	"encoding/hex"
+	"errors"
 	"fmt"
 	"math/big"
+	"runtime"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/apache/arrow-go/v18/arrow"
@@ -24,6 +27,9 @@
 // DuckDB query results can easily exceed the default 4MB limit.
 const MaxGRPCMessageSize = 1 << 30 // 1GB
 
+// ErrWorkerDead is returned when the backing worker process has crashed.
+var ErrWorkerDead = errors.New("flight worker is dead")
+
 // FlightExecutor implements QueryExecutor backed by an Arrow Flight SQL client.
 // It routes queries to a duckdb-service worker process over a Unix socket.
 type FlightExecutor struct {
@@ -31,6 +37,10 @@ type FlightExecutor struct {
 	sessionToken string
 	alloc        memory.Allocator
 	ownsClient   bool // if true, Close() closes the client
+
+	// dead is set to true when the backing worker crashes. Once set, all
+	// RPC methods return ErrWorkerDead without touching the gRPC client.
+	dead atomic.Bool
 }
 
 // NewFlightExecutor creates a FlightExecutor connected to the given address.
@@ -74,12 +84,43 @@ func NewFlightExecutorFromClient(client *flightsql.Client, sessionToken string)
 	}
 }
 
+// MarkDead marks this executor's backing worker as dead. All subsequent RPC
+// calls will return ErrWorkerDead without touching the (possibly closed) gRPC client.
+func (e *FlightExecutor) MarkDead() {
+	e.dead.Store(true)
+}
+
+// IsDead reports whether this executor has been marked dead.
+func (e *FlightExecutor) IsDead() bool {
+	return e.dead.Load()
+}
+
 // withSession adds the session token to the gRPC context.
 func (e *FlightExecutor) withSession(ctx context.Context) context.Context {
 	return metadata.AppendToOutgoingContext(ctx, "x-duckgres-session", e.sessionToken)
 }
 
-func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args ...any) (RowSet, error) {
+// recoverClientPanic converts a nil-pointer panic from a closed Flight SQL
+// client into an error. The arrow-go Close() method nils out the embedded
+// FlightServiceClient, so any concurrent RPC on the shared client panics.
+// Only nil-pointer dereferences are recovered; other panics are re-raised
+// to preserve stack traces for unrelated programmer errors.
+func recoverClientPanic(err *error) {
+	if r := recover(); r != nil {
+		if re, ok := r.(runtime.Error); ok && strings.Contains(re.Error(), "nil pointer") {
+			*err = fmt.Errorf("flight client panic (worker likely crashed): %v", r)
+			return
+		}
+		panic(r)
+	}
+}
+
+func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args ...any) (rs RowSet, err error) {
+	if e.dead.Load() {
+		return nil, ErrWorkerDead
+	}
+	defer recoverClientPanic(&err)
+
 	if len(args) > 0 {
 		query = interpolateArgs(query, args)
 	}
@@ -112,7 +153,12 @@
 	}, nil
 }
 
-func (e *FlightExecutor) ExecContext(ctx context.Context, query string, args ...any) (ExecResult, error) {
+func (e *FlightExecutor) ExecContext(ctx context.Context, query string, args ...any) (result ExecResult, err error) {
+	if e.dead.Load() {
+		return nil, ErrWorkerDead
+	}
+	defer recoverClientPanic(&err)
+
 	if len(args) > 0 {
 		query = interpolateArgs(query, args)
 	}
diff --git a/server/flight_executor_test.go b/server/flight_executor_test.go
new file mode 100644
index 0000000..e4d1d5b
--- /dev/null
+++ b/server/flight_executor_test.go
@@ -0,0 +1,134 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"runtime"
+	"strings"
+	"testing"
+)
+
+func TestFlightExecutorMarkDead_QueryContext(t *testing.T) {
+	// A dead executor should return ErrWorkerDead without touching the client.
+	e := &FlightExecutor{} // client is nil — would panic if accessed
+	e.MarkDead()
+
+	_, err := e.QueryContext(context.Background(), "SELECT 1")
+	if !errors.Is(err, ErrWorkerDead) {
+		t.Fatalf("expected ErrWorkerDead, got %v", err)
+	}
+}
+
+func TestFlightExecutorMarkDead_ExecContext(t *testing.T) {
+	e := &FlightExecutor{}
+	e.MarkDead()
+
+	_, err := e.ExecContext(context.Background(), "SET x = 1")
+	if !errors.Is(err, ErrWorkerDead) {
+		t.Fatalf("expected ErrWorkerDead, got %v", err)
+	}
+}
+
+func TestFlightExecutorMarkDeadIdempotent(t *testing.T) {
+	e := &FlightExecutor{}
+	e.MarkDead()
+	e.MarkDead() // should not panic
+
+	_, err := e.QueryContext(context.Background(), "SELECT 1")
+	if !errors.Is(err, ErrWorkerDead) {
+		t.Fatalf("expected ErrWorkerDead after double MarkDead, got %v", err)
+	}
+}
+
+func TestRecoverClientPanic_NilPointer(t *testing.T) {
+	var err error
+	func() {
+		defer recoverClientPanic(&err)
+		// Simulate the nil pointer dereference that arrow-go causes
+		var i *int
+		_ = *i //nolint:govet
+	}()
+
+	if err == nil {
+		t.Fatal("expected error from recovered nil pointer panic")
+	}
+	if !strings.Contains(err.Error(), "worker likely crashed") {
+		t.Fatalf("expected crash message, got: %v", err)
+	}
+}
+
+func TestRecoverClientPanic_NonNilPointerRePanics(t *testing.T) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			t.Fatal("expected re-panic for non-nil-pointer panic")
+		}
+		// Should be the original string panic, not a wrapped error
+		if s, ok := r.(string); !ok || s != "some other panic" {
+			t.Fatalf("expected original panic value, got: %v", r)
+		}
+	}()
+
+	var err error
+	func() {
+		defer recoverClientPanic(&err)
+		panic("some other panic")
+	}()
+
+	t.Fatal("should not reach here — panic should propagate")
+}
+
+func TestRecoverClientPanic_RuntimeErrorRePanics(t *testing.T) {
+	// runtime.Error that is NOT a nil pointer should re-panic
+	defer func() {
+		r := recover()
+		if r == nil {
+			t.Fatal("expected re-panic for non-nil-pointer runtime error")
+		}
+		if re, ok := r.(runtime.Error); !ok {
+			t.Fatalf("expected runtime.Error, got %T: %v", r, r)
+		} else if strings.Contains(re.Error(), "nil pointer") {
+			t.Fatal("this test should use a non-nil-pointer runtime error")
+		}
+	}()
+
+	var err error
+	func() {
+		defer recoverClientPanic(&err)
+		// Index out of range is a runtime.Error but not a nil pointer
+		s := []int{}
+		_ = s[1] //nolint:govet
+	}()
+
+	t.Fatal("should not reach here — runtime error should re-panic")
+}
+
+func TestFlightExecutorNilClient_QueryContextRecovers(t *testing.T) {
+	// Simulate the exact scenario: executor with a nil client (as if Close()
+	// nilled it out). QueryContext should recover the panic, not crash.
+	e := &FlightExecutor{
+		client: nil, // simulates closed client
+	}
+
+	_, err := e.QueryContext(context.Background(), "SELECT 1")
+	if err == nil {
+		t.Fatal("expected error from nil client")
+	}
+	if !strings.Contains(err.Error(), "worker likely crashed") {
+		t.Fatalf("expected crash recovery message, got: %v", err)
+	}
+}
+
+func TestFlightExecutorNilClient_ExecContextRecovers(t *testing.T) {
+	e := &FlightExecutor{
+		client: nil,
+	}
+
+	_, err := e.ExecContext(context.Background(), "SET x = 1")
+	if err == nil {
+		t.Fatal("expected error from nil client")
+	}
+	if !strings.Contains(err.Error(), "worker likely crashed") {
+		t.Fatalf("expected crash recovery message, got: %v", err)
+	}
+}