Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
- When following a plan file, mark tasks upon completion
- When creating a new branch from origin/main, do not track origin/main.
- Always run lint before committing.
- Parallelize using subagents when possible.
- Parallelize using subagents when possible.
- Prefer correctness, maintainability, and robustness over shortcut implementations.
37 changes: 14 additions & 23 deletions controlplane/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func RunControlPlane(cfg ControlPlaneConfig) {
// It is intentionally started after pre-warm to avoid concurrent worker
// creation races between pre-warm and first external Flight requests.
if cfg.FlightPort > 0 {
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, FlightIngressConfig{
flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, cp.rateLimiter, FlightIngressConfig{
SessionIdleTTL: cfg.FlightSessionIdleTTL,
SessionReapTick: cfg.FlightSessionReapInterval,
HandleIdleTTL: cfg.FlightHandleIdleTTL,
Expand Down Expand Up @@ -343,19 +343,13 @@ func (cp *ControlPlane) acceptLoop() {
func (cp *ControlPlane) handleConnection(conn net.Conn) {
remoteAddr := conn.RemoteAddr()

// Rate limiting
if msg := cp.rateLimiter.CheckConnection(remoteAddr); msg != "" {
releaseRateLimit, msg := server.BeginRateLimitedAuthAttempt(cp.rateLimiter, remoteAddr)
if msg != "" {
slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg)
_ = conn.Close()
return
}

if !cp.rateLimiter.RegisterConnection(remoteAddr) {
slog.Warn("Connection rejected: rate limit.", "remote_addr", remoteAddr)
_ = conn.Close()
return
}
defer cp.rateLimiter.UnregisterConnection(remoteAddr)
defer releaseRateLimit()

// Read startup message to determine SSL vs cancel
params, err := readStartupFromRaw(conn)
Expand All @@ -376,6 +370,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
// Require SSL
if !params.sslRequest {
slog.Warn("Connection rejected: SSL required.", "remote_addr", remoteAddr)
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
_ = conn.Close()
return
}
Expand Down Expand Up @@ -423,20 +418,12 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
database := startupParams["database"]

if username == "" {
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified")
_ = writer.Flush()
return
}

// Look up expected password for this user
expectedPassword, ok := cp.cfg.Users[username]
if !ok {
slog.Warn("Unknown user.", "user", username, "remote_addr", remoteAddr)
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
_ = writer.Flush()
return
}

// Request password
if err := server.WriteAuthCleartextPassword(writer); err != nil {
slog.Error("Failed to request password.", "remote_addr", remoteAddr, "error", err)
Expand All @@ -455,15 +442,19 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
}

if msgType != 'p' {
server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
_ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message")
_ = writer.Flush()
return
}

password := string(bytes.TrimRight(body, "\x00"))
if password != expectedPassword {
if !server.ValidateUserPassword(cp.cfg.Users, username, password) {
slog.Warn("Authentication failed.", "user", username, "remote_addr", remoteAddr)
cp.rateLimiter.RecordFailedAuth(remoteAddr)
banned := server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr)
if banned {
slog.Warn("IP banned after too many failed auth attempts.", "remote_addr", remoteAddr)
}
_ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed")
_ = writer.Flush()
return
Expand All @@ -475,7 +466,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
return
}

cp.rateLimiter.RecordSuccessfulAuth(remoteAddr)
server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr)
slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr)

// Create session on a worker
Expand Down Expand Up @@ -740,7 +731,7 @@ func (cp *ControlPlane) recoverFlightIngressAfterFailedReload() {
return
}

flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, FlightIngressConfig{
flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, cp.rateLimiter, FlightIngressConfig{
SessionIdleTTL: cp.cfg.FlightSessionIdleTTL,
SessionReapTick: cp.cfg.FlightSessionReapInterval,
HandleIdleTTL: cp.cfg.FlightHandleIdleTTL,
Expand Down
4 changes: 3 additions & 1 deletion controlplane/flight_ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"errors"

"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightsqlingress"
)

Expand All @@ -12,11 +13,12 @@ type FlightIngressConfig = flightsqlingress.Config
type FlightIngress = flightsqlingress.FlightIngress

// NewFlightIngress creates a control-plane Flight SQL ingress listener.
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, cfg FlightIngressConfig) (*FlightIngress, error) {
func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, rateLimiter *server.RateLimiter, cfg FlightIngressConfig) (*FlightIngress, error) {
return flightsqlingress.NewFlightIngress(host, port, tlsConfig, users, sm, cfg, flightsqlingress.Options{
IsMaxWorkersError: func(err error) bool {
return errors.Is(err, ErrMaxWorkersReached)
},
RateLimiter: rateLimiter,
Hooks: flightsqlingress.Hooks{
OnSessionCountChanged: observeFlightAuthSessions,
OnSessionsReaped: observeFlightSessionsReaped,
Expand Down
2 changes: 1 addition & 1 deletion controlplane/flight_ingress_adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package controlplane
import "testing"

func TestNewFlightIngressAdapterValidation(t *testing.T) {
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, FlightIngressConfig{})
_, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, nil, FlightIngressConfig{})
if err == nil {
t.Fatalf("expected validation error for invalid port")
}
Expand Down
89 changes: 89 additions & 0 deletions server/auth_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package server

import (
"crypto/subtle"
"net"
)

// invalidPasswordSentinel is the stand-in "expected password" that
// ValidateUserPassword compares against when the requested username is not
// present in the user map, so that a password comparison is still performed
// for unknown users.
const invalidPasswordSentinel = "__duckgres_invalid_password_sentinel__"

// BeginRateLimitedAuthAttempt applies the rate-limit policy ahead of an
// authentication attempt from remoteAddr.
//
// It returns a release callback and a rejection reason. An empty reason means
// the attempt may proceed; the caller must then invoke release exactly once
// when the attempt is over so the registered connection is unregistered. When
// the reason is non-empty the attempt was refused, the reject counter has been
// incremented, and release is a no-op.
//
// A nil rateLimiter disables all checks: the attempt is always admitted and
// release does nothing.
func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) (release func(), rejectReason string) {
	noop := func() {}

	// No limiter configured: nothing to enforce, nothing to release.
	if rateLimiter == nil {
		return noop, ""
	}

	// First gate: the limiter may refuse the address outright.
	if reason := rateLimiter.CheckConnection(remoteAddr); reason != "" {
		rateLimitRejectsCounter.Inc()
		return noop, reason
	}

	// Second gate: try to register this connection with the limiter.
	if rateLimiter.RegisterConnection(remoteAddr) {
		return func() { rateLimiter.UnregisterConnection(remoteAddr) }, ""
	}

	// Registration was refused. Ask the limiter again for a specific reason
	// and fall back to a generic message if it does not provide one.
	rateLimitRejectsCounter.Inc()
	reason := rateLimiter.CheckConnection(remoteAddr)
	if reason == "" {
		reason = "too many connections from your IP address"
	}
	return noop, reason
}

// RecordFailedAuthAttempt counts a failed authentication in the auth-failure
// metric and, when a rate limiter is configured, forwards the failure to it.
//
// The return value is the limiter's RecordFailedAuth result, which is true
// when this failure causes the source IP to be banned. It is always false when
// rateLimiter is nil.
func RecordFailedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) bool {
	// The metric is incremented whether or not rate limiting is enabled.
	authFailuresCounter.Inc()

	if rateLimiter != nil {
		return rateLimiter.RecordFailedAuth(remoteAddr)
	}
	return false
}

// RecordSuccessfulAuthAttempt reports a successful authentication from
// remoteAddr to the rate limiter so it can clear its failure tracking for that
// source. It is a no-op when rateLimiter is nil.
func RecordSuccessfulAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) {
	if rateLimiter != nil {
		rateLimiter.RecordSuccessfulAuth(remoteAddr)
	}
}

// ValidateUserPassword reports whether username exists in users and password
// matches the stored password for that user.
//
// To avoid revealing whether a user exists through comparison timing, the
// password comparison is carried out even for unknown users, against a fixed
// sentinel value. An unknown user never validates, whatever password is
// supplied.
func ValidateUserPassword(users map[string]string, username, password string) bool {
	want, known := users[username]
	if !known {
		// Compare against a placeholder so the comparison still runs.
		want = invalidPasswordSentinel
	}

	if !constantTimeStringEqual(password, want) {
		return false
	}
	// A match only counts when the user actually exists.
	return known
}

// constantTimeStringEqual reports whether a and b contain exactly the same
// bytes.
//
// The loop always walks as many positions as the longer input has and does not
// stop at the first differing byte. Positions past the end of the shorter
// input are treated as zero bytes, so the lengths are compared separately to
// reject inputs that differ only by trailing zero bytes.
func constantTimeStringEqual(a, b string) bool {
	span := len(a)
	if span < len(b) {
		span = len(b)
	}

	// acc accumulates the XOR of every byte pair; it stays zero only if all
	// compared positions match.
	var acc byte
	for pos := 0; pos < span; pos++ {
		var left, right byte
		if pos < len(a) {
			left = a[pos]
		}
		if pos < len(b) {
			right = b[pos]
		}
		acc |= left ^ right
	}

	if subtle.ConstantTimeEq(int32(len(a)), int32(len(b))) != 1 {
		return false
	}
	return acc == 0
}
120 changes: 120 additions & 0 deletions server/auth_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package server

import (
"net"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)

// counterMetricValue gathers metrics from the default Prometheus gatherer and
// returns the sum of all counter samples in the family named metricName.
//
// The test is failed immediately if gathering fails, if the family exists but
// is not a counter, or if no family with that name is found.
func counterMetricValue(t *testing.T, metricName string) float64 {
	t.Helper()

	gathered, err := prometheus.DefaultGatherer.Gather()
	if err != nil {
		t.Fatalf("failed to gather metrics: %v", err)
	}

	for _, family := range gathered {
		if family.GetName() == metricName {
			if family.GetType() != dto.MetricType_COUNTER {
				t.Fatalf("metric %q is not a counter", metricName)
			}
			// Add up every sample in the family.
			sum := 0.0
			for _, sample := range family.GetMetric() {
				sum += sample.GetCounter().GetValue()
			}
			return sum
		}
	}

	t.Fatalf("metric %q not found", metricName)
	return 0
}

// TestValidateUserPassword checks the three basic outcomes of
// ValidateUserPassword: a correct credential pair is accepted, a wrong
// password is rejected, and an unknown user is rejected even when the supplied
// password matches another user's.
func TestValidateUserPassword(t *testing.T) {
	users := map[string]string{"postgres": "postgres"}

	checks := []struct {
		username string
		password string
		want     bool
		failMsg  string
	}{
		{"postgres", "postgres", true, "expected valid credentials to pass"},
		{"postgres", "wrong", false, "expected wrong password to fail"},
		{"unknown", "postgres", false, "expected unknown user to fail"},
	}

	for _, c := range checks {
		if ValidateUserPassword(users, c.username, c.password) != c.want {
			t.Fatalf("%s", c.failMsg)
		}
	}
}

// TestRecordFailedAuthAttemptIncrementsMetricAndBans verifies that, with a
// limiter configured to allow a single failed attempt, one call to
// RecordFailedAuthAttempt reports a ban and raises
// duckgres_auth_failures_total by exactly one.
func TestRecordFailedAuthAttemptIncrementsMetricAndBans(t *testing.T) {
	const metric = "duckgres_auth_failures_total"

	limiter := NewRateLimiter(RateLimitConfig{
		MaxFailedAttempts:   1,
		FailedAttemptWindow: time.Minute,
		BanDuration:         time.Hour,
		MaxConnectionsPerIP: 10,
	})
	client := &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 41000}

	// Sample the counter on both sides of the call so only the delta matters.
	start := counterMetricValue(t, metric)
	gotBan := RecordFailedAuthAttempt(limiter, client)
	end := counterMetricValue(t, metric)

	if !gotBan {
		t.Fatalf("expected failed auth attempt to ban when threshold is 1")
	}
	if delta := end - start; delta != 1 {
		t.Fatalf("expected duckgres_auth_failures_total delta 1, got %.0f", delta)
	}
}

// TestBeginRateLimitedAuthAttemptRejectsBannedAndIncrementsMetric records one
// failed auth against a limiter whose failure threshold is one, then verifies
// that the next BeginRateLimitedAuthAttempt from the same address is refused
// with a non-empty reason and raises duckgres_rate_limit_rejects_total by
// exactly one.
func TestBeginRateLimitedAuthAttemptRejectsBannedAndIncrementsMetric(t *testing.T) {
	const metric = "duckgres_rate_limit_rejects_total"

	limiter := NewRateLimiter(RateLimitConfig{
		MaxFailedAttempts:   1,
		FailedAttemptWindow: time.Minute,
		BanDuration:         time.Hour,
		MaxConnectionsPerIP: 10,
	})
	client := &net.TCPAddr{IP: net.ParseIP("203.0.113.11"), Port: 41001}

	// A single failure reaches the configured threshold.
	limiter.RecordFailedAuth(client)

	start := counterMetricValue(t, metric)
	done, why := BeginRateLimitedAuthAttempt(limiter, client)
	done()
	end := counterMetricValue(t, metric)

	if why == "" {
		t.Fatalf("expected non-empty rejection reason for banned client")
	}
	if delta := end - start; delta != 1 {
		t.Fatalf("expected duckgres_rate_limit_rejects_total delta 1, got %.0f", delta)
	}
}

// TestBeginRateLimitedAuthAttemptRegistersAndReleases uses a limiter that
// permits one connection per IP and checks the register/release lifecycle:
// the first attempt is admitted, a second attempt made while the first is
// still held is refused, and a third attempt is admitted once the first has
// been released.
func TestBeginRateLimitedAuthAttemptRegistersAndReleases(t *testing.T) {
	limiter := NewRateLimiter(RateLimitConfig{
		MaxFailedAttempts:   5,
		FailedAttemptWindow: time.Minute,
		BanDuration:         time.Hour,
		MaxConnectionsPerIP: 1,
	})
	client := &net.TCPAddr{IP: net.ParseIP("203.0.113.12"), Port: 41002}

	// First attempt takes the only available slot.
	releaseFirst, firstReason := BeginRateLimitedAuthAttempt(limiter, client)
	if firstReason != "" {
		t.Fatalf("unexpected first attempt rejection: %q", firstReason)
	}

	// Second attempt overlaps with the first and must be turned away.
	releaseSecond, secondReason := BeginRateLimitedAuthAttempt(limiter, client)
	releaseSecond()
	if secondReason == "" {
		t.Fatalf("expected second concurrent attempt to be rejected")
	}

	// Free the slot held by the first attempt.
	releaseFirst()

	// With the slot free again, a new attempt is admitted.
	releaseThird, thirdReason := BeginRateLimitedAuthAttempt(limiter, client)
	releaseThird()
	if thirdReason != "" {
		t.Fatalf("expected third attempt to succeed after release, got %q", thirdReason)
	}
}
Loading