Skip to content

Commit a211ccb

Browse files
authored
Align Flight auth hardening + token flow with PGWire parity (#224)
* fix: align Flight auth hardening and telemetry with PGWire * Align Flight auth behavior * fix: remove SHA256 from credential compare path
1 parent 62340a6 commit a211ccb

File tree

9 files changed

+461
-58
lines changed

9 files changed

+461
-58
lines changed

AGENTS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
- When following a plan file, mark tasks upon completion
99
- When creating new branch from origin/main, do not track origin/main.
1010
- Always run lint before committing.
11-
- Parallelize using subagents when possible.
11+
- Parallelize using subagents when possible.
12+
- Prefer correctness, maintanability, robustness over shortcut implementations

controlplane/control.go

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ func RunControlPlane(cfg ControlPlaneConfig) {
170170
// It is intentionally started after pre-warm to avoid concurrent worker
171171
// creation races between pre-warm and first external Flight requests.
172172
if cfg.FlightPort > 0 {
173-
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, FlightIngressConfig{
173+
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, cp.rateLimiter, FlightIngressConfig{
174174
SessionIdleTTL: cfg.FlightSessionIdleTTL,
175175
SessionReapTick: cfg.FlightSessionReapInterval,
176176
HandleIdleTTL: cfg.FlightHandleIdleTTL,
@@ -343,19 +343,13 @@ func (cp *ControlPlane) acceptLoop() {
343343
func (cp *ControlPlane) handleConnection(conn net.Conn) {
344344
remoteAddr := conn.RemoteAddr()
345345

346-
// Rate limiting
347-
if msg := cp.rateLimiter.CheckConnection(remoteAddr); msg != "" {
346+
releaseRateLimit, msg := server.BeginRateLimitedAuthAttempt(cp.rateLimiter, remoteAddr)
347+
if msg != "" {
348348
slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg)
349349
_ = conn.Close()
350350
return
351351
}
352-
353-
if !cp.rateLimiter.RegisterConnection(remoteAddr) {
354-
slog.Warn("Connection rejected: rate limit.", "remote_addr", remoteAddr)
355-
_ = conn.Close()
356-
return
357-
}
358-
defer cp.rateLimiter.UnregisterConnection(remoteAddr)
352+
defer releaseRateLimit()
359353

360354
// Read startup message to determine SSL vs cancel
361355
params, err := readStartupFromRaw(conn)
@@ -376,6 +370,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
376370
// Require SSL
377371
if !params.sslRequest {
378372
slog.Warn("Connection rejected: SSL required.", "remote_addr", remoteAddr)
373+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
379374
_ = conn.Close()
380375
return
381376
}
@@ -423,20 +418,12 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
423418
database := startupParams["database"]
424419

425420
if username == "" {
421+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
426422
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified")
427423
_ = writer.Flush()
428424
return
429425
}
430426

431-
// Look up expected password for this user
432-
expectedPassword, ok := cp.cfg.Users[username]
433-
if !ok {
434-
slog.Warn("Unknown user.", "user", username, "remote_addr", remoteAddr)
435-
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
436-
_ = writer.Flush()
437-
return
438-
}
439-
440427
// Request password
441428
if err := server.WriteAuthCleartextPassword(writer); err != nil {
442429
slog.Error("Failed to request password.", "remote_addr", remoteAddr, "error", err)
@@ -455,15 +442,19 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
455442
}
456443

457444
if msgType != 'p' {
445+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
458446
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message")
459447
_ = writer.Flush()
460448
return
461449
}
462450

463451
password := string(bytes.TrimRight(body, "\x00"))
464-
if password != expectedPassword {
452+
if !server.ValidateUserPassword(cp.cfg.Users, username, password) {
465453
slog.Warn("Authentication failed.", "user", username, "remote_addr", remoteAddr)
466-
cp.rateLimiter.RecordFailedAuth(remoteAddr)
454+
banned := server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
455+
if banned {
456+
slog.Warn("IP banned after too many failed auth attempts.", "remote_addr", remoteAddr)
457+
}
467458
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
468459
_ = writer.Flush()
469460
return
@@ -475,7 +466,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
475466
return
476467
}
477468

478-
cp.rateLimiter.RecordSuccessfulAuth(remoteAddr)
469+
server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr)
479470
slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr)
480471

481472
// Create session on a worker
@@ -740,7 +731,7 @@ func (cp *ControlPlane) recoverFlightIngressAfterFailedReload() {
740731
return
741732
}
742733

743-
flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, FlightIngressConfig{
734+
flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, cp.rateLimiter, FlightIngressConfig{
744735
SessionIdleTTL: cp.cfg.FlightSessionIdleTTL,
745736
SessionReapTick: cp.cfg.FlightSessionReapInterval,
746737
HandleIdleTTL: cp.cfg.FlightHandleIdleTTL,

controlplane/flight_ingress.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/tls"
55
"errors"
66

7+
"github.com/posthog/duckgres/server"
78
"github.com/posthog/duckgres/server/flightsqlingress"
89
)
910

@@ -12,11 +13,12 @@ type FlightIngressConfig = flightsqlingress.Config
1213
type FlightIngress = flightsqlingress.FlightIngress
1314

1415
// NewFlightIngress creates a control-plane Flight SQL ingress listener.
15-
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, cfg FlightIngressConfig) (*FlightIngress, error) {
16+
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, rateLimiter *server.RateLimiter, cfg FlightIngressConfig) (*FlightIngress, error) {
1617
return flightsqlingress.NewFlightIngress(host, port, tlsConfig, users, sm, cfg, flightsqlingress.Options{
1718
IsMaxWorkersError: func(err error) bool {
1819
return errors.Is(err, ErrMaxWorkersReached)
1920
},
21+
RateLimiter: rateLimiter,
2022
Hooks: flightsqlingress.Hooks{
2123
OnSessionCountChanged: observeFlightAuthSessions,
2224
OnSessionsReaped: observeFlightSessionsReaped,

controlplane/flight_ingress_adapter_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package controlplane
33
import "testing"
44

55
func TestNewFlightIngressAdapterValidation(t *testing.T) {
6-
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, FlightIngressConfig{})
6+
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, nil, FlightIngressConfig{})
77
if err == nil {
88
t.Fatalf("expected validation error for invalid port")
99
}

server/auth_policy.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package server
2+
3+
import (
4+
"crypto/subtle"
5+
"net"
6+
)
7+
8+
const invalidPasswordSentinel = "__duckgres_invalid_password_sentinel__"
9+
10+
// BeginRateLimitedAuthAttempt enforces rate-limit policy before an auth attempt.
11+
// The returned release function must be called once the attempt is complete.
12+
func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) (release func(), rejectReason string) {
13+
release = func() {}
14+
if rateLimiter == nil {
15+
return release, ""
16+
}
17+
18+
if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" {
19+
rateLimitRejectsCounter.Inc()
20+
return release, msg
21+
}
22+
if !rateLimiter.RegisterConnection(remoteAddr) {
23+
rateLimitRejectsCounter.Inc()
24+
if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" {
25+
return release, msg
26+
}
27+
return release, "too many connections from your IP address"
28+
}
29+
30+
return func() {
31+
rateLimiter.UnregisterConnection(remoteAddr)
32+
}, ""
33+
}
34+
35+
// RecordFailedAuthAttempt records auth telemetry and updates rate-limit state.
36+
// Returns true when this failure causes the source IP to be banned.
37+
func RecordFailedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) bool {
38+
authFailuresCounter.Inc()
39+
if rateLimiter == nil {
40+
return false
41+
}
42+
return rateLimiter.RecordFailedAuth(remoteAddr)
43+
}
44+
45+
// RecordSuccessfulAuthAttempt clears failure tracking after successful auth.
46+
func RecordSuccessfulAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) {
47+
if rateLimiter == nil {
48+
return
49+
}
50+
rateLimiter.RecordSuccessfulAuth(remoteAddr)
51+
}
52+
53+
// ValidateUserPassword validates username/password without leaking user existence
54+
// via credential-compare timing differences.
55+
func ValidateUserPassword(users map[string]string, username, password string) bool {
56+
expectedPassword, userFound := users[username]
57+
if !userFound {
58+
expectedPassword = invalidPasswordSentinel
59+
}
60+
61+
passwordMatches := constantTimeStringEqual(password, expectedPassword)
62+
return userFound && passwordMatches
63+
}
64+
65+
func constantTimeStringEqual(a, b string) bool {
66+
ab := []byte(a)
67+
bb := []byte(b)
68+
69+
maxLen := len(ab)
70+
if len(bb) > maxLen {
71+
maxLen = len(bb)
72+
}
73+
74+
var diff byte
75+
for i := 0; i < maxLen; i++ {
76+
var av byte
77+
var bv byte
78+
if i < len(ab) {
79+
av = ab[i]
80+
}
81+
if i < len(bb) {
82+
bv = bb[i]
83+
}
84+
diff |= av ^ bv
85+
}
86+
87+
lengthsEqual := subtle.ConstantTimeEq(int32(len(ab)), int32(len(bb))) == 1
88+
return lengthsEqual && diff == 0
89+
}

server/auth_policy_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package server
2+
3+
import (
4+
"net"
5+
"testing"
6+
"time"
7+
8+
"github.com/prometheus/client_golang/prometheus"
9+
dto "github.com/prometheus/client_model/go"
10+
)
11+
12+
func counterMetricValue(t *testing.T, metricName string) float64 {
13+
t.Helper()
14+
families, err := prometheus.DefaultGatherer.Gather()
15+
if err != nil {
16+
t.Fatalf("failed to gather metrics: %v", err)
17+
}
18+
for _, fam := range families {
19+
if fam.GetName() != metricName {
20+
continue
21+
}
22+
if fam.GetType() != dto.MetricType_COUNTER {
23+
t.Fatalf("metric %q is not a counter", metricName)
24+
}
25+
var total float64
26+
for _, metric := range fam.GetMetric() {
27+
total += metric.GetCounter().GetValue()
28+
}
29+
return total
30+
}
31+
t.Fatalf("metric %q not found", metricName)
32+
return 0
33+
}
34+
35+
func TestValidateUserPassword(t *testing.T) {
36+
users := map[string]string{"postgres": "postgres"}
37+
38+
if !ValidateUserPassword(users, "postgres", "postgres") {
39+
t.Fatalf("expected valid credentials to pass")
40+
}
41+
if ValidateUserPassword(users, "postgres", "wrong") {
42+
t.Fatalf("expected wrong password to fail")
43+
}
44+
if ValidateUserPassword(users, "unknown", "postgres") {
45+
t.Fatalf("expected unknown user to fail")
46+
}
47+
}
48+
49+
func TestRecordFailedAuthAttemptIncrementsMetricAndBans(t *testing.T) {
50+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 41000}
51+
rl := NewRateLimiter(RateLimitConfig{
52+
MaxFailedAttempts: 1,
53+
FailedAttemptWindow: time.Minute,
54+
BanDuration: time.Hour,
55+
MaxConnectionsPerIP: 10,
56+
})
57+
58+
before := counterMetricValue(t, "duckgres_auth_failures_total")
59+
banned := RecordFailedAuthAttempt(rl, addr)
60+
after := counterMetricValue(t, "duckgres_auth_failures_total")
61+
62+
if !banned {
63+
t.Fatalf("expected failed auth attempt to ban when threshold is 1")
64+
}
65+
if after-before != 1 {
66+
t.Fatalf("expected duckgres_auth_failures_total delta 1, got %.0f", after-before)
67+
}
68+
}
69+
70+
func TestBeginRateLimitedAuthAttemptRejectsBannedAndIncrementsMetric(t *testing.T) {
71+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.11"), Port: 41001}
72+
rl := NewRateLimiter(RateLimitConfig{
73+
MaxFailedAttempts: 1,
74+
FailedAttemptWindow: time.Minute,
75+
BanDuration: time.Hour,
76+
MaxConnectionsPerIP: 10,
77+
})
78+
rl.RecordFailedAuth(addr)
79+
80+
before := counterMetricValue(t, "duckgres_rate_limit_rejects_total")
81+
release, reason := BeginRateLimitedAuthAttempt(rl, addr)
82+
release()
83+
after := counterMetricValue(t, "duckgres_rate_limit_rejects_total")
84+
85+
if reason == "" {
86+
t.Fatalf("expected non-empty rejection reason for banned client")
87+
}
88+
if after-before != 1 {
89+
t.Fatalf("expected duckgres_rate_limit_rejects_total delta 1, got %.0f", after-before)
90+
}
91+
}
92+
93+
func TestBeginRateLimitedAuthAttemptRegistersAndReleases(t *testing.T) {
94+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.12"), Port: 41002}
95+
rl := NewRateLimiter(RateLimitConfig{
96+
MaxFailedAttempts: 5,
97+
FailedAttemptWindow: time.Minute,
98+
BanDuration: time.Hour,
99+
MaxConnectionsPerIP: 1,
100+
})
101+
102+
release1, reason1 := BeginRateLimitedAuthAttempt(rl, addr)
103+
if reason1 != "" {
104+
t.Fatalf("unexpected first attempt rejection: %q", reason1)
105+
}
106+
107+
release2, reason2 := BeginRateLimitedAuthAttempt(rl, addr)
108+
release2()
109+
if reason2 == "" {
110+
t.Fatalf("expected second concurrent attempt to be rejected")
111+
}
112+
113+
release1()
114+
115+
release3, reason3 := BeginRateLimitedAuthAttempt(rl, addr)
116+
release3()
117+
if reason3 != "" {
118+
t.Fatalf("expected third attempt to succeed after release, got %q", reason3)
119+
}
120+
}

0 commit comments

Comments
 (0)