Skip to content

Commit 2ee8da5

Browse files
authored
Refactor PostgreSQL handshake for JDBC compatibility (#275)
* Refactor PostgreSQL handshake for JDBC compatibility and optimize session creation * Fix GSS fallback handling and enforce default session limits * fix: address PR 275 review comments - Fix CancelRequest race condition by registering the query before CreateSession - Fix worker_mgr reaper to respect minWorkers as an idle pool size rather than total pool size - Fix uncancelled context leak for tmpCC in controlplane * Fix warm worker floor and queued session cancellation
1 parent 7d1b8a6 commit 2ee8da5

File tree

12 files changed

+468
-93
lines changed

12 files changed

+468
-93
lines changed

controlplane/control.go

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func RunControlPlane(cfg ControlPlaneConfig) {
188188
minWorkers = maxWorkers
189189
}
190190

191-
pool := NewFlightWorkerPool(cfg.SocketDir, cfg.ConfigPath, maxWorkers)
191+
pool := NewFlightWorkerPool(cfg.SocketDir, cfg.ConfigPath, minWorkers, maxWorkers)
192192
pool.idleTimeout = cfg.WorkerIdleTimeout
193193

194194
// Import pre-bound sockets from handover, or pre-bind new ones.
@@ -392,6 +392,21 @@ func (cp *ControlPlane) acceptLoop() {
392392
}
393393
}
394394

395+
func createSessionWithRegisteredCancel(
396+
srv *server.Server,
397+
timeout time.Duration,
398+
key server.BackendKey,
399+
createFn func(context.Context) (int32, *server.FlightExecutor, error),
400+
) (int32, *server.FlightExecutor, error) {
401+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
402+
defer cancel()
403+
404+
srv.RegisterQuery(key, cancel)
405+
defer srv.UnregisterQuery(key)
406+
407+
return createFn(ctx)
408+
}
409+
395410
func (cp *ControlPlane) handleConnection(conn net.Conn) {
396411
remoteAddr := conn.RemoteAddr()
397412
slog.Info("Connection accepted.", "remote_addr", remoteAddr)
@@ -412,7 +427,8 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
412427
return
413428
}
414429

415-
// Read startup message to determine SSL vs cancel
430+
// Read startup message to determine SSL vs cancel.
431+
// readStartupFromRaw handles GSSENC probes by replying 'N' and continuing.
416432
params, err := readStartupFromRaw(conn)
417433
if err != nil {
418434
if err == io.EOF || errors.Is(err, io.EOF) {
@@ -543,14 +559,47 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
543559
server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr)
544560
slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr)
545561

562+
// Feed initial parameters and backend key data to the client IMMEDIATELY.
563+
// This keeps JDBC drivers happy while we perform the slow worker acquisition.
564+
pid := cp.sessions.ReservePID()
565+
secretKey := server.GenerateSecretKey()
566+
567+
// Use a temporary clientConn just to send initial params
568+
tmpCC := server.NewClientConn(cp.srv, nil, nil, writer, username, database, applicationName, nil, pid, secretKey, -1)
569+
defer server.CancelClientConn(tmpCC)
570+
server.SendInitialParams(tmpCC)
571+
if err := writer.Flush(); err != nil {
572+
slog.Error("Failed to flush initial params.", "remote_addr", remoteAddr, "error", err)
573+
return
574+
}
575+
546576
// Create session on a worker. The timeout controls how long we wait in the
547577
// worker queue when all slots are occupied.
548-
ctx, cancel := context.WithTimeout(context.Background(), cp.cfg.WorkerQueueTimeout)
549-
pid, executor, err := cp.sessions.CreateSession(ctx, username)
550-
cancel()
578+
// Pass resource limits to be applied immediately by the worker (one RPC).
579+
var (
580+
memLimit string
581+
threads int
582+
)
583+
if cp.rebalancer != nil {
584+
memLimit = cp.rebalancer.MemoryLimit()
585+
threads = cp.rebalancer.PerSessionThreads()
586+
}
587+
588+
_, executor, err := createSessionWithRegisteredCancel(
589+
cp.srv,
590+
cp.cfg.WorkerQueueTimeout,
591+
server.BackendKey{Pid: pid, SecretKey: secretKey},
592+
func(ctx context.Context) (int32, *server.FlightExecutor, error) {
593+
return cp.sessions.CreateSession(ctx, username, pid, memLimit, threads)
594+
},
595+
)
551596
if err != nil {
552597
slog.Error("Failed to create session.", "user", username, "remote_addr", remoteAddr, "error", err)
553-
_ = server.WriteErrorResponse(writer, "FATAL", "53300", "too many connections")
598+
if errors.Is(err, context.Canceled) {
599+
_ = server.WriteErrorResponse(writer, "FATAL", "57014", "canceling authentication due to user request")
600+
} else {
601+
_ = server.WriteErrorResponse(writer, "FATAL", "53300", "too many connections")
602+
}
554603
_ = writer.Flush()
555604
return
556605
}
@@ -560,14 +609,11 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
560609
// the message loop if the backing worker dies.
561610
cp.sessions.SetConnCloser(pid, tlsConn)
562611

563-
secretKey := server.GenerateSecretKey()
564-
565-
// Create clientConn with FlightExecutor
612+
// Create real clientConn with FlightExecutor and worker assignment
566613
workerID := cp.sessions.WorkerIDForPID(pid)
567614
cc := server.NewClientConn(cp.srv, tlsConn, reader, writer, username, database, applicationName, executor, pid, secretKey, workerID)
568615

569-
// Send initial parameters and ReadyForQuery
570-
server.SendInitialParams(cc)
616+
// Send ReadyForQuery to signal that the handshake is complete
571617
if err := server.WriteReadyForQuery(writer, 'I'); err != nil {
572618
slog.Error("Failed to send ReadyForQuery.", "remote_addr", remoteAddr, "error", err)
573619
return
@@ -589,6 +635,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
589635
// startupResult holds the parsed initial startup message.
590636
type startupResult struct {
591637
sslRequest bool
638+
gssRequest bool
592639
cancelRequest bool
593640
cancelPid int32
594641
cancelSecretKey int32
@@ -649,6 +696,28 @@ func readStartupFromRaw(conn net.Conn) (startupResult, error) {
649696
return startupResult{}, fmt.Errorf("too many negotiation rounds")
650697
}
651698

699+
// readStartupWithGSSFallback accepts a GSSAPI probe, rejects it with 'N',
700+
// and keeps reading startup packets on the same connection so clients can
701+
// continue with SSLRequest/startup without reconnecting.
702+
func readStartupWithGSSFallback(conn net.Conn) (startupResult, error) {
703+
for i := 0; i < 4; i++ {
704+
params, err := readStartupFromRaw(conn)
705+
if err != nil {
706+
return startupResult{}, err
707+
}
708+
709+
if !params.gssRequest {
710+
return params, nil
711+
}
712+
713+
if _, err := conn.Write([]byte{'N'}); err != nil {
714+
return startupResult{}, fmt.Errorf("write GSSAPI rejection: %w", err)
715+
}
716+
}
717+
718+
return startupResult{}, fmt.Errorf("too many GSSAPI startup requests")
719+
}
720+
652721
func fullRead(conn net.Conn, buf []byte) (int, error) {
653722
total := 0
654723
for total < len(buf) {
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package controlplane
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/posthog/duckgres/server"
9+
)
10+
11+
func TestCreateSessionWithRegisteredCancel_CancelQueryCancelsWait(t *testing.T) {
12+
srv := &server.Server{}
13+
server.InitMinimalServer(srv, server.Config{}, nil)
14+
15+
key := server.BackendKey{Pid: 1234, SecretKey: 5678}
16+
17+
started := make(chan struct{})
18+
errCh := make(chan error, 1)
19+
go func() {
20+
_, _, err := createSessionWithRegisteredCancel(
21+
srv,
22+
200*time.Millisecond,
23+
key,
24+
func(ctx context.Context) (int32, *server.FlightExecutor, error) {
25+
close(started)
26+
<-ctx.Done()
27+
return 0, nil, ctx.Err()
28+
},
29+
)
30+
errCh <- err
31+
}()
32+
33+
select {
34+
case <-started:
35+
case <-time.After(2 * time.Second):
36+
t.Fatal("create function did not start")
37+
}
38+
39+
if !srv.CancelQuery(key) {
40+
t.Fatal("expected CancelQuery to find registered query")
41+
}
42+
43+
select {
44+
case err := <-errCh:
45+
if err == nil {
46+
t.Fatal("expected context cancellation error")
47+
}
48+
if err != context.Canceled {
49+
t.Fatalf("expected context.Canceled, got %v", err)
50+
}
51+
case <-time.After(2 * time.Second):
52+
t.Fatal("createSessionWithRegisteredCancel did not return after cancel")
53+
}
54+
55+
if srv.CancelQuery(key) {
56+
t.Fatal("expected query to be unregistered after return")
57+
}
58+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package controlplane
2+
3+
import (
4+
"encoding/binary"
5+
"net"
6+
"testing"
7+
)
8+
9+
func startupPacket(protocolVersion uint32) []byte {
10+
pkt := make([]byte, 8)
11+
binary.BigEndian.PutUint32(pkt[0:4], uint32(len(pkt)))
12+
binary.BigEndian.PutUint32(pkt[4:8], protocolVersion)
13+
return pkt
14+
}
15+
16+
func TestReadStartupWithGSSFallback(t *testing.T) {
17+
serverConn, clientConn := net.Pipe()
18+
defer func() { _ = serverConn.Close() }()
19+
defer func() { _ = clientConn.Close() }()
20+
21+
errCh := make(chan error, 1)
22+
go func() {
23+
// Send GSSENCRequest first.
24+
if _, err := clientConn.Write(startupPacket(80877104)); err != nil {
25+
errCh <- err
26+
return
27+
}
28+
29+
// Server should reject GSS with a single 'N' byte.
30+
resp := make([]byte, 1)
31+
if _, err := clientConn.Read(resp); err != nil {
32+
errCh <- err
33+
return
34+
}
35+
if resp[0] != 'N' {
36+
errCh <- net.InvalidAddrError("expected GSS rejection byte 'N'")
37+
return
38+
}
39+
40+
// Continue negotiation on the same connection with SSLRequest.
41+
if _, err := clientConn.Write(startupPacket(80877103)); err != nil {
42+
errCh <- err
43+
return
44+
}
45+
errCh <- nil
46+
}()
47+
48+
params, err := readStartupWithGSSFallback(serverConn)
49+
if err != nil {
50+
t.Fatalf("readStartupWithGSSFallback returned error: %v", err)
51+
}
52+
if !params.sslRequest {
53+
t.Fatalf("expected sslRequest=true after GSS fallback, got %+v", params)
54+
}
55+
if params.gssRequest {
56+
t.Fatalf("expected gssRequest=false final result, got %+v", params)
57+
}
58+
59+
if err := <-errCh; err != nil {
60+
t.Fatalf("client side negotiation failed: %v", err)
61+
}
62+
}

controlplane/session_mgr.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,25 @@ func NewSessionManager(pool *FlightWorkerPool, rebalancer *MemoryRebalancer) *Se
4444
return sm
4545
}
4646

47+
// ReservePID generates a new unique PID for a session.
48+
func (sm *SessionManager) ReservePID() int32 {
49+
return sm.nextPID.Add(1)
50+
}
51+
4752
// CreateSession acquires a worker (reusing an idle one or spawning a new one),
4853
// creates a session on it, and rebalances memory/thread limits across all active sessions.
49-
func (sm *SessionManager) CreateSession(ctx context.Context, username string) (int32, *server.FlightExecutor, error) {
54+
// If pid is 0, a new one is generated.
55+
func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
56+
memoryLimit, threads = sm.resolveSessionLimits(memoryLimit, threads)
57+
5058
// Acquire a worker: reuses idle pre-warmed workers or spawns a new one.
5159
// When max-workers is set, this blocks until a slot is available.
5260
worker, err := sm.pool.AcquireWorker(ctx)
5361
if err != nil {
5462
return 0, nil, fmt.Errorf("acquire worker: %w", err)
5563
}
5664

57-
sessionToken, err := worker.CreateSession(ctx, username)
65+
sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, threads)
5866
if err != nil {
5967
// Clean up the worker we just spawned (but not if it was a pre-warmed idle worker
6068
// that has sessions from other concurrent requests).
@@ -65,7 +73,9 @@ func (sm *SessionManager) CreateSession(ctx context.Context, username string) (i
6573
// Create FlightExecutor sharing the worker's existing gRPC connection
6674
executor := server.NewFlightExecutorFromClient(worker.client, sessionToken)
6775

68-
pid := sm.nextPID.Add(1)
76+
if pid == 0 {
77+
pid = sm.nextPID.Add(1)
78+
}
6979

7080
session := &ManagedSession{
7181
PID: pid,
@@ -81,16 +91,27 @@ func (sm *SessionManager) CreateSession(ctx context.Context, username string) (i
8191

8292
slog.Debug("Session created.", "pid", pid, "worker", worker.ID, "user", username)
8393

84-
// Set memory/thread limits on this session synchronously so it never
85-
// runs with unlimited resources.
94+
// Update other sessions if rebalancing is enabled.
8695
if sm.rebalancer != nil {
87-
sm.rebalancer.SetInitialLimits(ctx, session)
8896
sm.rebalancer.RequestRebalance()
8997
}
9098

9199
return pid, executor, nil
92100
}
93101

102+
func (sm *SessionManager) resolveSessionLimits(memoryLimit string, threads int) (string, int) {
103+
if sm.rebalancer == nil {
104+
return memoryLimit, threads
105+
}
106+
if memoryLimit == "" {
107+
memoryLimit = sm.rebalancer.MemoryLimit()
108+
}
109+
if threads <= 0 {
110+
threads = sm.rebalancer.PerSessionThreads()
111+
}
112+
return memoryLimit, threads
113+
}
114+
94115
// DestroySession destroys a session, retires its dedicated worker, and rebalances
95116
// memory/thread limits across remaining sessions.
96117
func (sm *SessionManager) DestroySession(pid int32) {
@@ -239,4 +260,3 @@ func (sm *SessionManager) AllSessions() []*ManagedSession {
239260
}
240261
return result
241262
}
242-

controlplane/session_mgr_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,29 @@ func TestDestroySessionAfterOnWorkerCrash(t *testing.T) {
251251
t.Fatal("expected 0 sessions after DestroySession")
252252
}
253253
}
254+
255+
func TestResolveSessionLimits_UsesRebalancerDefaultsWhenUnset(t *testing.T) {
256+
r := NewMemoryRebalancer(24*1024*1024*1024, 8, nil, false)
257+
sm := NewSessionManager(&FlightWorkerPool{workers: make(map[int]*ManagedWorker)}, r)
258+
259+
mem, threads := sm.resolveSessionLimits("", 0)
260+
if mem != "24576MB" {
261+
t.Fatalf("expected memory limit 24576MB, got %q", mem)
262+
}
263+
if threads != 8 {
264+
t.Fatalf("expected threads 8, got %d", threads)
265+
}
266+
}
267+
268+
func TestResolveSessionLimits_PreservesExplicitValues(t *testing.T) {
269+
r := NewMemoryRebalancer(24*1024*1024*1024, 8, nil, false)
270+
sm := NewSessionManager(&FlightWorkerPool{workers: make(map[int]*ManagedWorker)}, r)
271+
272+
mem, threads := sm.resolveSessionLimits("1024MB", 2)
273+
if mem != "1024MB" {
274+
t.Fatalf("expected memory limit 1024MB, got %q", mem)
275+
}
276+
if threads != 2 {
277+
t.Fatalf("expected threads 2, got %d", threads)
278+
}
279+
}

0 commit comments

Comments
 (0)