Skip to content

Commit cc3bad2

Browse files
committed
Merge main into pr-225 and resolve conflicts
Resolved conflicts in: - controlplane/control.go: Combined rate limiter from main with queue timeout config from PR. - controlplane/flight_ingress.go: Included rate limiter, removed obsolete IsMaxWorkersError. - server/flightsqlingress/ingress.go: Updated Options to include RateLimiter. - server/flightsqlingress/ingress_test.go: Kept new rate limiter tests from main, removed deleted retry tests.
2 parents 767f46b + a211ccb commit cc3bad2

File tree

9 files changed

+462
-59
lines changed

9 files changed

+462
-59
lines changed

AGENTS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
- When following a plan file, mark tasks upon completion
99
- When creating new branch from origin/main, do not track origin/main.
1010
- Always run lint before committing.
11-
- Parallelize using subagents when possible.
11+
- Parallelize using subagents when possible.
12+
- Prefer correctness, maintainability, and robustness over shortcut implementations.

controlplane/control.go

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func RunControlPlane(cfg ControlPlaneConfig) {
173173
// It is intentionally started after pre-warm to avoid concurrent worker
174174
// creation races between pre-warm and first external Flight requests.
175175
if cfg.FlightPort > 0 {
176-
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, FlightIngressConfig{
176+
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, cp.rateLimiter, FlightIngressConfig{
177177
SessionIdleTTL: cfg.FlightSessionIdleTTL,
178178
SessionReapTick: cfg.FlightSessionReapInterval,
179179
HandleIdleTTL: cfg.FlightHandleIdleTTL,
@@ -348,19 +348,13 @@ func (cp *ControlPlane) acceptLoop() {
348348
func (cp *ControlPlane) handleConnection(conn net.Conn) {
349349
remoteAddr := conn.RemoteAddr()
350350

351-
// Rate limiting
352-
if msg := cp.rateLimiter.CheckConnection(remoteAddr); msg != "" {
351+
releaseRateLimit, msg := server.BeginRateLimitedAuthAttempt(cp.rateLimiter, remoteAddr)
352+
if msg != "" {
353353
slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg)
354354
_ = conn.Close()
355355
return
356356
}
357-
358-
if !cp.rateLimiter.RegisterConnection(remoteAddr) {
359-
slog.Warn("Connection rejected: rate limit.", "remote_addr", remoteAddr)
360-
_ = conn.Close()
361-
return
362-
}
363-
defer cp.rateLimiter.UnregisterConnection(remoteAddr)
357+
defer releaseRateLimit()
364358

365359
// Read startup message to determine SSL vs cancel
366360
params, err := readStartupFromRaw(conn)
@@ -381,6 +375,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
381375
// Require SSL
382376
if !params.sslRequest {
383377
slog.Warn("Connection rejected: SSL required.", "remote_addr", remoteAddr)
378+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
384379
_ = conn.Close()
385380
return
386381
}
@@ -428,20 +423,12 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
428423
database := startupParams["database"]
429424

430425
if username == "" {
426+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
431427
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified")
432428
_ = writer.Flush()
433429
return
434430
}
435431

436-
// Look up expected password for this user
437-
expectedPassword, ok := cp.cfg.Users[username]
438-
if !ok {
439-
slog.Warn("Unknown user.", "user", username, "remote_addr", remoteAddr)
440-
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
441-
_ = writer.Flush()
442-
return
443-
}
444-
445432
// Request password
446433
if err := server.WriteAuthCleartextPassword(writer); err != nil {
447434
slog.Error("Failed to request password.", "remote_addr", remoteAddr, "error", err)
@@ -460,15 +447,19 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
460447
}
461448

462449
if msgType != 'p' {
450+
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
463451
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message")
464452
_ = writer.Flush()
465453
return
466454
}
467455

468456
password := string(bytes.TrimRight(body, "\x00"))
469-
if password != expectedPassword {
457+
if !server.ValidateUserPassword(cp.cfg.Users, username, password) {
470458
slog.Warn("Authentication failed.", "user", username, "remote_addr", remoteAddr)
471-
cp.rateLimiter.RecordFailedAuth(remoteAddr)
459+
banned := server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
460+
if banned {
461+
slog.Warn("IP banned after too many failed auth attempts.", "remote_addr", remoteAddr)
462+
}
472463
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
473464
_ = writer.Flush()
474465
return
@@ -480,7 +471,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
480471
return
481472
}
482473

483-
cp.rateLimiter.RecordSuccessfulAuth(remoteAddr)
474+
server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr)
484475
slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr)
485476

486477
// Create session on a worker. The timeout controls how long we wait in the
@@ -747,7 +738,7 @@ func (cp *ControlPlane) recoverFlightIngressAfterFailedReload() {
747738
return
748739
}
749740

750-
flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, FlightIngressConfig{
741+
flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, cp.rateLimiter, FlightIngressConfig{
751742
SessionIdleTTL: cp.cfg.FlightSessionIdleTTL,
752743
SessionReapTick: cp.cfg.FlightSessionReapInterval,
753744
HandleIdleTTL: cp.cfg.FlightHandleIdleTTL,

controlplane/flight_ingress.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package controlplane
33
import (
44
"crypto/tls"
55

6+
"github.com/posthog/duckgres/server"
67
"github.com/posthog/duckgres/server/flightsqlingress"
78
)
89

@@ -11,8 +12,9 @@ type FlightIngressConfig = flightsqlingress.Config
1112
type FlightIngress = flightsqlingress.FlightIngress
1213

1314
// NewFlightIngress creates a control-plane Flight SQL ingress listener.
14-
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, cfg FlightIngressConfig) (*FlightIngress, error) {
15+
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, rateLimiter *server.RateLimiter, cfg FlightIngressConfig) (*FlightIngress, error) {
1516
return flightsqlingress.NewFlightIngress(host, port, tlsConfig, users, sm, cfg, flightsqlingress.Options{
17+
RateLimiter: rateLimiter,
1618
Hooks: flightsqlingress.Hooks{
1719
OnSessionCountChanged: observeFlightAuthSessions,
1820
OnSessionsReaped: observeFlightSessionsReaped,

controlplane/flight_ingress_adapter_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package controlplane
33
import "testing"
44

55
func TestNewFlightIngressAdapterValidation(t *testing.T) {
6-
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, FlightIngressConfig{})
6+
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, nil, FlightIngressConfig{})
77
if err == nil {
88
t.Fatalf("expected validation error for invalid port")
99
}

server/auth_policy.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package server
2+
3+
import (
4+
"crypto/subtle"
5+
"net"
6+
)
7+
8+
const invalidPasswordSentinel = "__duckgres_invalid_password_sentinel__"
9+
10+
// BeginRateLimitedAuthAttempt enforces rate-limit policy before an auth attempt.
11+
// The returned release function must be called once the attempt is complete.
12+
func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) (release func(), rejectReason string) {
13+
release = func() {}
14+
if rateLimiter == nil {
15+
return release, ""
16+
}
17+
18+
if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" {
19+
rateLimitRejectsCounter.Inc()
20+
return release, msg
21+
}
22+
if !rateLimiter.RegisterConnection(remoteAddr) {
23+
rateLimitRejectsCounter.Inc()
24+
if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" {
25+
return release, msg
26+
}
27+
return release, "too many connections from your IP address"
28+
}
29+
30+
return func() {
31+
rateLimiter.UnregisterConnection(remoteAddr)
32+
}, ""
33+
}
34+
35+
// RecordFailedAuthAttempt records auth telemetry and updates rate-limit state.
36+
// Returns true when this failure causes the source IP to be banned.
37+
func RecordFailedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) bool {
38+
authFailuresCounter.Inc()
39+
if rateLimiter == nil {
40+
return false
41+
}
42+
return rateLimiter.RecordFailedAuth(remoteAddr)
43+
}
44+
45+
// RecordSuccessfulAuthAttempt clears failure tracking after successful auth.
46+
func RecordSuccessfulAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) {
47+
if rateLimiter == nil {
48+
return
49+
}
50+
rateLimiter.RecordSuccessfulAuth(remoteAddr)
51+
}
52+
53+
// ValidateUserPassword validates username/password without leaking user existence
54+
// via credential-compare timing differences.
55+
func ValidateUserPassword(users map[string]string, username, password string) bool {
56+
expectedPassword, userFound := users[username]
57+
if !userFound {
58+
expectedPassword = invalidPasswordSentinel
59+
}
60+
61+
passwordMatches := constantTimeStringEqual(password, expectedPassword)
62+
return userFound && passwordMatches
63+
}
64+
65+
// constantTimeStringEqual reports whether a and b are byte-for-byte identical.
//
// The loop always runs over the longer of the two inputs and accumulates the
// differences instead of returning at the first mismatch, so the position of
// a mismatch is not revealed by an early exit. Bytes past the end of the
// shorter input are treated as zero; the separate length comparison keeps a
// string from matching a NUL-padded variant of itself.
func constantTimeStringEqual(a, b string) bool {
	n := len(a)
	if len(b) > n {
		n = len(b)
	}

	// Index the strings directly rather than converting to []byte, which
	// would allocate and leave extra copies of the secrets on the heap.
	var diff byte
	for i := 0; i < n; i++ {
		var av, bv byte
		if i < len(a) {
			av = a[i]
		}
		if i < len(b) {
			bv = b[i]
		}
		diff |= av ^ bv
	}

	// Compare the lengths at full width. Truncating them to int32 would make
	// lengths that differ by a multiple of 2^32 look equal.
	lenXor := uint64(len(a)) ^ uint64(len(b))
	lengthsEqual := subtle.ConstantTimeEq(int32(uint32(lenXor)), 0) &
		subtle.ConstantTimeEq(int32(uint32(lenXor>>32)), 0)

	return lengthsEqual&subtle.ConstantTimeByteEq(diff, 0) == 1
}

server/auth_policy_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package server
2+
3+
import (
4+
"net"
5+
"testing"
6+
"time"
7+
8+
"github.com/prometheus/client_golang/prometheus"
9+
dto "github.com/prometheus/client_model/go"
10+
)
11+
12+
func counterMetricValue(t *testing.T, metricName string) float64 {
13+
t.Helper()
14+
families, err := prometheus.DefaultGatherer.Gather()
15+
if err != nil {
16+
t.Fatalf("failed to gather metrics: %v", err)
17+
}
18+
for _, fam := range families {
19+
if fam.GetName() != metricName {
20+
continue
21+
}
22+
if fam.GetType() != dto.MetricType_COUNTER {
23+
t.Fatalf("metric %q is not a counter", metricName)
24+
}
25+
var total float64
26+
for _, metric := range fam.GetMetric() {
27+
total += metric.GetCounter().GetValue()
28+
}
29+
return total
30+
}
31+
t.Fatalf("metric %q not found", metricName)
32+
return 0
33+
}
34+
35+
func TestValidateUserPassword(t *testing.T) {
36+
users := map[string]string{"postgres": "postgres"}
37+
38+
if !ValidateUserPassword(users, "postgres", "postgres") {
39+
t.Fatalf("expected valid credentials to pass")
40+
}
41+
if ValidateUserPassword(users, "postgres", "wrong") {
42+
t.Fatalf("expected wrong password to fail")
43+
}
44+
if ValidateUserPassword(users, "unknown", "postgres") {
45+
t.Fatalf("expected unknown user to fail")
46+
}
47+
}
48+
49+
func TestRecordFailedAuthAttemptIncrementsMetricAndBans(t *testing.T) {
50+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 41000}
51+
rl := NewRateLimiter(RateLimitConfig{
52+
MaxFailedAttempts: 1,
53+
FailedAttemptWindow: time.Minute,
54+
BanDuration: time.Hour,
55+
MaxConnectionsPerIP: 10,
56+
})
57+
58+
before := counterMetricValue(t, "duckgres_auth_failures_total")
59+
banned := RecordFailedAuthAttempt(rl, addr)
60+
after := counterMetricValue(t, "duckgres_auth_failures_total")
61+
62+
if !banned {
63+
t.Fatalf("expected failed auth attempt to ban when threshold is 1")
64+
}
65+
if after-before != 1 {
66+
t.Fatalf("expected duckgres_auth_failures_total delta 1, got %.0f", after-before)
67+
}
68+
}
69+
70+
func TestBeginRateLimitedAuthAttemptRejectsBannedAndIncrementsMetric(t *testing.T) {
71+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.11"), Port: 41001}
72+
rl := NewRateLimiter(RateLimitConfig{
73+
MaxFailedAttempts: 1,
74+
FailedAttemptWindow: time.Minute,
75+
BanDuration: time.Hour,
76+
MaxConnectionsPerIP: 10,
77+
})
78+
rl.RecordFailedAuth(addr)
79+
80+
before := counterMetricValue(t, "duckgres_rate_limit_rejects_total")
81+
release, reason := BeginRateLimitedAuthAttempt(rl, addr)
82+
release()
83+
after := counterMetricValue(t, "duckgres_rate_limit_rejects_total")
84+
85+
if reason == "" {
86+
t.Fatalf("expected non-empty rejection reason for banned client")
87+
}
88+
if after-before != 1 {
89+
t.Fatalf("expected duckgres_rate_limit_rejects_total delta 1, got %.0f", after-before)
90+
}
91+
}
92+
93+
func TestBeginRateLimitedAuthAttemptRegistersAndReleases(t *testing.T) {
94+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.12"), Port: 41002}
95+
rl := NewRateLimiter(RateLimitConfig{
96+
MaxFailedAttempts: 5,
97+
FailedAttemptWindow: time.Minute,
98+
BanDuration: time.Hour,
99+
MaxConnectionsPerIP: 1,
100+
})
101+
102+
release1, reason1 := BeginRateLimitedAuthAttempt(rl, addr)
103+
if reason1 != "" {
104+
t.Fatalf("unexpected first attempt rejection: %q", reason1)
105+
}
106+
107+
release2, reason2 := BeginRateLimitedAuthAttempt(rl, addr)
108+
release2()
109+
if reason2 == "" {
110+
t.Fatalf("expected second concurrent attempt to be rejected")
111+
}
112+
113+
release1()
114+
115+
release3, reason3 := BeginRateLimitedAuthAttempt(rl, addr)
116+
release3()
117+
if reason3 != "" {
118+
t.Fatalf("expected third attempt to succeed after release, got %q", reason3)
119+
}
120+
}

0 commit comments

Comments
 (0)