diff --git a/AGENTS.md b/AGENTS.md index 99aefae..2e43d16 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,4 +8,5 @@ - When following a plan file, mark tasks upon completion - When creating new branch from origin/main, do not track origin/main. - Always run lint before committing. -- Parallelize using subagents when possible. \ No newline at end of file +- Parallelize using subagents when possible. +- Prefer correctness, maintanability, robustness over shortcut implementations diff --git a/controlplane/control.go b/controlplane/control.go index 06058b6..cfd20f7 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -170,7 +170,7 @@ func RunControlPlane(cfg ControlPlaneConfig) { // It is intentionally started after pre-warm to avoid concurrent worker // creation races between pre-warm and first external Flight requests. if cfg.FlightPort > 0 { - flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, FlightIngressConfig{ + flightIngress, err := NewFlightIngress(cfg.Host, cfg.FlightPort, tlsCfg, cfg.Users, sessions, cp.rateLimiter, FlightIngressConfig{ SessionIdleTTL: cfg.FlightSessionIdleTTL, SessionReapTick: cfg.FlightSessionReapInterval, HandleIdleTTL: cfg.FlightHandleIdleTTL, @@ -343,19 +343,13 @@ func (cp *ControlPlane) acceptLoop() { func (cp *ControlPlane) handleConnection(conn net.Conn) { remoteAddr := conn.RemoteAddr() - // Rate limiting - if msg := cp.rateLimiter.CheckConnection(remoteAddr); msg != "" { + releaseRateLimit, msg := server.BeginRateLimitedAuthAttempt(cp.rateLimiter, remoteAddr) + if msg != "" { slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg) _ = conn.Close() return } - - if !cp.rateLimiter.RegisterConnection(remoteAddr) { - slog.Warn("Connection rejected: rate limit.", "remote_addr", remoteAddr) - _ = conn.Close() - return - } - defer cp.rateLimiter.UnregisterConnection(remoteAddr) + defer releaseRateLimit() // Read startup message to determine SSL vs cancel params, err := readStartupFromRaw(conn) @@ -376,6 +370,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { // Require SSL if !params.sslRequest { slog.Warn("Connection rejected: SSL required.", "remote_addr", remoteAddr) + server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr) _ = conn.Close() return } @@ -423,20 +418,12 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { database := startupParams["database"] if username == "" { + server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr) _ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified") _ = writer.Flush() return } - // Look up expected password for this user - expectedPassword, ok := cp.cfg.Users[username] - if !ok { - slog.Warn("Unknown user.", "user", username, "remote_addr", remoteAddr) - _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") - _ = writer.Flush() - return - } - // Request password if err := server.WriteAuthCleartextPassword(writer); err != nil { slog.Error("Failed to request password.", "remote_addr", remoteAddr, "error", err) @@ -455,15 +442,19 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { } if msgType != 'p' { + server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr) _ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message") _ = writer.Flush() return } password := string(bytes.TrimRight(body, "\x00")) - if password != expectedPassword { + if !server.ValidateUserPassword(cp.cfg.Users, username, password) { slog.Warn("Authentication failed.", "user", username, "remote_addr", remoteAddr) - cp.rateLimiter.RecordFailedAuth(remoteAddr) + banned := server.RecordFailedAuthAttempt(cp.rateLimiter, remoteAddr) + if banned { + slog.Warn("IP banned after too many failed auth attempts.", "remote_addr", remoteAddr) + } _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") _ = writer.Flush() return @@ -475,7 +466,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { return } - cp.rateLimiter.RecordSuccessfulAuth(remoteAddr) + server.RecordSuccessfulAuthAttempt(cp.rateLimiter, remoteAddr) slog.Info("User authenticated.", "user", username, "remote_addr", remoteAddr) // Create session on a worker @@ -740,7 +731,7 @@ func (cp *ControlPlane) recoverFlightIngressAfterFailedReload() { return } - flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, FlightIngressConfig{ + flightIngress, err := NewFlightIngress(cp.cfg.Host, cp.cfg.FlightPort, cp.tlsConfig, cp.cfg.Users, cp.sessions, cp.rateLimiter, FlightIngressConfig{ SessionIdleTTL: cp.cfg.FlightSessionIdleTTL, SessionReapTick: cp.cfg.FlightSessionReapInterval, HandleIdleTTL: cp.cfg.FlightHandleIdleTTL, diff --git a/controlplane/flight_ingress.go b/controlplane/flight_ingress.go index 0921854..74bd733 100644 --- a/controlplane/flight_ingress.go +++ b/controlplane/flight_ingress.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" + "github.com/posthog/duckgres/server" "github.com/posthog/duckgres/server/flightsqlingress" ) @@ -12,11 +13,12 @@ type FlightIngressConfig = flightsqlingress.Config type FlightIngress = flightsqlingress.FlightIngress // NewFlightIngress creates a control-plane Flight SQL ingress listener. -func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, cfg FlightIngressConfig) (*FlightIngress, error) { +func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[string]string, sm *SessionManager, rateLimiter *server.RateLimiter, cfg FlightIngressConfig) (*FlightIngress, error) { return flightsqlingress.NewFlightIngress(host, port, tlsConfig, users, sm, cfg, flightsqlingress.Options{ IsMaxWorkersError: func(err error) bool { return errors.Is(err, ErrMaxWorkersReached) }, + RateLimiter: rateLimiter, Hooks: flightsqlingress.Hooks{ OnSessionCountChanged: observeFlightAuthSessions, OnSessionsReaped: observeFlightSessionsReaped, diff --git a/controlplane/flight_ingress_adapter_test.go b/controlplane/flight_ingress_adapter_test.go index a078f55..2187453 100644 --- a/controlplane/flight_ingress_adapter_test.go +++ b/controlplane/flight_ingress_adapter_test.go @@ -3,7 +3,7 @@ package controlplane import "testing" func TestNewFlightIngressAdapterValidation(t *testing.T) { - _, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, FlightIngressConfig{}) + _, err := NewFlightIngress("127.0.0.1", 0, nil, map[string]string{}, nil, nil, FlightIngressConfig{}) if err == nil { t.Fatalf("expected validation error for invalid port") } diff --git a/server/auth_policy.go b/server/auth_policy.go new file mode 100644 index 0000000..9c56f16 --- /dev/null +++ b/server/auth_policy.go @@ -0,0 +1,89 @@ +package server + +import ( + "crypto/subtle" + "net" +) + +const invalidPasswordSentinel = "__duckgres_invalid_password_sentinel__" + +// BeginRateLimitedAuthAttempt enforces rate-limit policy before an auth attempt. +// The returned release function must be called once the attempt is complete. +func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) (release func(), rejectReason string) { + release = func() {} + if rateLimiter == nil { + return release, "" + } + + if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" { + rateLimitRejectsCounter.Inc() + return release, msg + } + if !rateLimiter.RegisterConnection(remoteAddr) { + rateLimitRejectsCounter.Inc() + if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" { + return release, msg + } + return release, "too many connections from your IP address" + } + + return func() { + rateLimiter.UnregisterConnection(remoteAddr) + }, "" +} + +// RecordFailedAuthAttempt records auth telemetry and updates rate-limit state. +// Returns true when this failure causes the source IP to be banned. +func RecordFailedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) bool { + authFailuresCounter.Inc() + if rateLimiter == nil { + return false + } + return rateLimiter.RecordFailedAuth(remoteAddr) +} + +// RecordSuccessfulAuthAttempt clears failure tracking after successful auth. +func RecordSuccessfulAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) { + if rateLimiter == nil { + return + } + rateLimiter.RecordSuccessfulAuth(remoteAddr) +} + +// ValidateUserPassword validates username/password without leaking user existence +// via credential-compare timing differences. +func ValidateUserPassword(users map[string]string, username, password string) bool { + expectedPassword, userFound := users[username] + if !userFound { + expectedPassword = invalidPasswordSentinel + } + + passwordMatches := constantTimeStringEqual(password, expectedPassword) + return userFound && passwordMatches +} + +func constantTimeStringEqual(a, b string) bool { + ab := []byte(a) + bb := []byte(b) + + maxLen := len(ab) + if len(bb) > maxLen { + maxLen = len(bb) + } + + var diff byte + for i := 0; i < maxLen; i++ { + var av byte + var bv byte + if i < len(ab) { + av = ab[i] + } + if i < len(bb) { + bv = bb[i] + } + diff |= av ^ bv + } + + lengthsEqual := subtle.ConstantTimeEq(int32(len(ab)), int32(len(bb))) == 1 + return lengthsEqual && diff == 0 +} diff --git a/server/auth_policy_test.go b/server/auth_policy_test.go new file mode 100644 index 0000000..5739a58 --- /dev/null +++ b/server/auth_policy_test.go @@ -0,0 +1,120 @@ +package server + +import ( + "net" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" +) + +func counterMetricValue(t *testing.T, metricName string) float64 { + t.Helper() + families, err := prometheus.DefaultGatherer.Gather() + if err != nil { + t.Fatalf("failed to gather metrics: %v", err) + } + for _, fam := range families { + if fam.GetName() != metricName { + continue + } + if fam.GetType() != dto.MetricType_COUNTER { + t.Fatalf("metric %q is not a counter", metricName) + } + var total float64 + for _, metric := range fam.GetMetric() { + total += metric.GetCounter().GetValue() + } + return total + } + t.Fatalf("metric %q not found", metricName) + return 0 +} + +func TestValidateUserPassword(t *testing.T) { + users := map[string]string{"postgres": "postgres"} + + if !ValidateUserPassword(users, "postgres", "postgres") { + t.Fatalf("expected valid credentials to pass") + } + if ValidateUserPassword(users, "postgres", "wrong") { + t.Fatalf("expected wrong password to fail") + } + if ValidateUserPassword(users, "unknown", "postgres") { + t.Fatalf("expected unknown user to fail") + } +} + +func TestRecordFailedAuthAttemptIncrementsMetricAndBans(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 41000} + rl := NewRateLimiter(RateLimitConfig{ + MaxFailedAttempts: 1, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 10, + }) + + before := counterMetricValue(t, "duckgres_auth_failures_total") + banned := RecordFailedAuthAttempt(rl, addr) + after := counterMetricValue(t, "duckgres_auth_failures_total") + + if !banned { + t.Fatalf("expected failed auth attempt to ban when threshold is 1") + } + if after-before != 1 { + t.Fatalf("expected duckgres_auth_failures_total delta 1, got %.0f", after-before) + } +} + +func TestBeginRateLimitedAuthAttemptRejectsBannedAndIncrementsMetric(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.11"), Port: 41001} + rl := NewRateLimiter(RateLimitConfig{ + MaxFailedAttempts: 1, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 10, + }) + rl.RecordFailedAuth(addr) + + before := counterMetricValue(t, "duckgres_rate_limit_rejects_total") + release, reason := BeginRateLimitedAuthAttempt(rl, addr) + release() + after := counterMetricValue(t, "duckgres_rate_limit_rejects_total") + + if reason == "" { + t.Fatalf("expected non-empty rejection reason for banned client") + } + if after-before != 1 { + t.Fatalf("expected duckgres_rate_limit_rejects_total delta 1, got %.0f", after-before) + } +} + +func TestBeginRateLimitedAuthAttemptRegistersAndReleases(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.12"), Port: 41002} + rl := NewRateLimiter(RateLimitConfig{ + MaxFailedAttempts: 5, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 1, + }) + + release1, reason1 := BeginRateLimitedAuthAttempt(rl, addr) + if reason1 != "" { + t.Fatalf("unexpected first attempt rejection: %q", reason1) + } + + release2, reason2 := BeginRateLimitedAuthAttempt(rl, addr) + release2() + if reason2 == "" { + t.Fatalf("expected second concurrent attempt to be rejected") + } + + release1() + + release3, reason3 := BeginRateLimitedAuthAttempt(rl, addr) + release3() + if reason3 != "" { + t.Fatalf("expected third attempt to succeed after release, got %q", reason3) + } +} diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index a8fd320..37ad877 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -3,7 +3,6 @@ package flightsqlingress import ( "context" "crypto/rand" - "crypto/subtle" "crypto/tls" "database/sql" "encoding/base64" @@ -70,6 +69,7 @@ type Hooks struct { type Options struct { IsMaxWorkersError func(error) bool Hooks Hooks + RateLimiter *server.RateLimiter } // FlightIngress serves Arrow Flight SQL on the control plane with Basic auth. @@ -123,6 +123,7 @@ func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[st _ = ln.Close() return nil, err } + handler.rateLimiter = opts.RateLimiter grpcOpts := []grpc.ServerOption{ grpc.MaxRecvMsgSize(server.MaxGRPCMessageSize), @@ -180,9 +181,10 @@ func (fi *FlightIngress) Shutdown() { // ControlPlaneFlightSQLHandler implements Flight SQL over control-plane sessions. type ControlPlaneFlightSQLHandler struct { flightsql.BaseServer - users map[string]string - sessions *flightAuthSessionStore - alloc memory.Allocator + users map[string]string + sessions *flightAuthSessionStore + rateLimiter *server.RateLimiter + alloc memory.Allocator } func NewControlPlaneFlightSQLHandler(sessions *flightAuthSessionStore, users map[string]string) (*ControlPlaneFlightSQLHandler, error) { @@ -204,8 +206,14 @@ func NewControlPlaneFlightSQLHandler(sessions *flightAuthSessionStore, users map } func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) (*flightClientSession, error) { + var remoteAddr net.Addr + if p, ok := peer.FromContext(ctx); ok && p != nil { + remoteAddr = p.Addr + } + md, ok := metadata.FromIncomingContext(ctx) if !ok { + server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) return nil, status.Error(codes.Unauthenticated, "missing metadata") } @@ -213,19 +221,24 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) ( return nil, status.Error(codes.Unavailable, "session store is not configured") } - username, err := h.authenticateBasicCredentials(md) - if err != nil { - return nil, err - } - if sessionToken := incomingSessionToken(md); sessionToken != "" { s, ok := h.sessions.GetByToken(sessionToken) if !ok { + server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) return nil, status.Error(codes.Unauthenticated, "session not found") } - if username != s.username { - return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user") + // When Basic auth is included alongside a bearer session token, enforce + // principal consistency. Token-only auth is allowed after bootstrap. + if hasAuthorizationHeader(md) { + username, err := h.authenticateBasicCredentials(md, remoteAddr) + if err != nil { + return nil, err + } + if username != s.username { + server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) + return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user") + } } setSessionTokenMetadata(ctx, sessionToken) @@ -233,6 +246,19 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) ( return s, nil } + // Bootstrap requires Basic auth and is subject to auth rate limiting. + releaseRateLimit, rejectReason := server.BeginRateLimitedAuthAttempt(h.rateLimiter, remoteAddr) + defer releaseRateLimit() + if rejectReason != "" { + slog.Warn("Flight auth rejected by rate limit policy.", "remote_addr", remoteAddr, "reason", rejectReason) + return nil, status.Error(codes.ResourceExhausted, "authentication rate limit exceeded") + } + + username, err := h.authenticateBasicCredentials(md, remoteAddr) + if err != nil { + return nil, err + } + s, err := h.sessions.Create(ctx, username) if err != nil { return nil, status.Errorf(codes.Unavailable, "create bootstrap session: %v", err) @@ -243,25 +269,31 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) ( return s, nil } -func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(md metadata.MD) (string, error) { +func hasAuthorizationHeader(md metadata.MD) bool { + return len(md.Get("authorization")) > 0 +} + +func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(md metadata.MD, remoteAddr net.Addr) (string, error) { authHeaders := md.Get("authorization") if len(authHeaders) == 0 { + server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) return "", status.Error(codes.Unauthenticated, "missing authorization header") } username, password, err := parseBasicCredentials(authHeaders[0]) if err != nil { + server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) return "", status.Error(codes.Unauthenticated, err.Error()) } - expected, userFound := h.users[username] - if !userFound { - expected = "__invalid__" - } - passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expected)) == 1 - if !userFound || !passMatch { + if !server.ValidateUserPassword(h.users, username, password) { + banned := server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr) + if banned { + slog.Warn("Flight client IP banned after auth failures.", "remote_addr", remoteAddr) + } return "", status.Error(codes.Unauthenticated, "invalid credentials") } + server.RecordSuccessfulAuthAttempt(h.rateLimiter, remoteAddr) return username, nil } diff --git a/server/flightsqlingress/ingress_test.go b/server/flightsqlingress/ingress_test.go index 42aac86..6a2a57a 100644 --- a/server/flightsqlingress/ingress_test.go +++ b/server/flightsqlingress/ingress_test.go @@ -13,8 +13,12 @@ import ( "time" "github.com/posthog/duckgres/server" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) var errTestMaxWorkersReached = errors.New("max workers reached") @@ -28,6 +32,57 @@ func (r testExecResult) RowsAffected() (int64, error) { return r.affected, r.err } +func metricCounterValue(t *testing.T, metricName string) float64 { + t.Helper() + families, err := prometheus.DefaultGatherer.Gather() + if err != nil { + t.Fatalf("failed to gather metrics: %v", err) + } + for _, fam := range families { + if fam.GetName() != metricName { + continue + } + if fam.GetType() != dto.MetricType_COUNTER { + t.Fatalf("metric %q is not a counter", metricName) + } + var total float64 + for _, metric := range fam.GetMetric() { + total += metric.GetCounter().GetValue() + } + return total + } + t.Fatalf("metric %q not found", metricName) + return 0 +} + +func authContextForPeer(addr net.Addr, username, password string) context.Context { + token := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + base := peer.NewContext(context.Background(), &peer.Peer{Addr: addr}) + return metadata.NewIncomingContext(base, metadata.Pairs("authorization", "Basic "+token)) +} + +func testFlightHandlerWithStoreAndRateLimiter(t *testing.T, users map[string]string, rateLimiter *server.RateLimiter) *ControlPlaneFlightSQLHandler { + t.Helper() + store := &flightAuthSessionStore{ + idleTTL: time.Minute, + reapInterval: time.Hour, + handleIdleTTL: time.Minute, + sessions: make(map[string]*flightClientSession), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + createSessionFn: func(context.Context, string) (int32, *server.FlightExecutor, error) { + return 1234, nil, nil + }, + destroySessionFn: func(int32) {}, + } + h, err := NewControlPlaneFlightSQLHandler(store, users) + if err != nil { + t.Fatalf("NewControlPlaneFlightSQLHandler returned error: %v", err) + } + h.rateLimiter = rateLimiter + return h +} + func TestParseBasicCredentials(t *testing.T) { token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) user, pass, err := parseBasicCredentials("Basic " + token) @@ -117,10 +172,12 @@ func TestFlightAuthSessionKeyDoesNotTrustMetadataClientOverride(t *testing.T) { } } -func TestSessionFromContextRejectsServerIssuedSessionTokenWithoutBasicAuth(t *testing.T) { +func TestSessionFromContextAcceptsServerIssuedSessionTokenWithoutBasicAuth(t *testing.T) { + s := newFlightClientSession(1234, "postgres", nil) + s.token = "issued-token" store := &flightAuthSessionStore{ sessions: map[string]*flightClientSession{ - "issued-token": newFlightClientSession(1234, "postgres", nil), + "issued-token": s, }, } @@ -130,8 +187,15 @@ func TestSessionFromContextRejectsServerIssuedSessionTokenWithoutBasicAuth(t *te } ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", "issued-token")) - if _, err := h.sessionFromContext(ctx); err == nil { - t.Fatalf("expected token-only auth to be rejected") + got, err := h.sessionFromContext(ctx) + if err != nil { + t.Fatalf("expected token-only auth to succeed, got %v", err) + } + if got == nil { + t.Fatalf("expected non-nil session") + } + if got != s { + t.Fatalf("expected existing token session to be reused") } } @@ -198,6 +262,44 @@ func TestSessionFromContextAcceptsServerIssuedSessionTokenWithBasicAuth(t *testi } } +func TestSessionFromContextTokenPathDoesNotClearRateLimiterFailures(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.47"), Port: 30004} + rateLimiter := server.NewRateLimiter(server.RateLimitConfig{ + MaxFailedAttempts: 2, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 100, + }) + rateLimiter.RecordFailedAuth(addr) + + s := newFlightClientSession(1234, "postgres", nil) + s.token = "issued-token" + store := &flightAuthSessionStore{ + sessions: map[string]*flightClientSession{ + "issued-token": s, + }, + } + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + h.rateLimiter = rateLimiter + + base := peer.NewContext(context.Background(), &peer.Peer{Addr: addr}) + ctx := metadata.NewIncomingContext(base, metadata.Pairs("x-duckgres-session", "issued-token")) + if _, err := h.sessionFromContext(ctx); err != nil { + t.Fatalf("token-only auth failed: %v", err) + } + + _, err = h.sessionFromContext(authContextForPeer(addr, "postgres", "wrong")) + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("expected unauthenticated error for bad password, got %v", err) + } + if !rateLimiter.IsBanned(addr) { + t.Fatalf("expected prior failure + new failure to ban; token-only path should not clear failures") + } +} + func TestSessionFromContextWithoutTokenCreatesDistinctSessions(t *testing.T) { var createCalls atomic.Int32 store := &flightAuthSessionStore{ @@ -367,6 +469,79 @@ func TestNewControlPlaneFlightSQLHandlerReturnsError(t *testing.T) { } } +func TestSessionFromContextInvalidCredentialsIncrementsAuthFailureMetric(t *testing.T) { + h := testFlightHandlerWithStoreAndRateLimiter(t, map[string]string{"postgres": "postgres"}, nil) + ctx := authContextForPeer(&net.TCPAddr{IP: net.ParseIP("203.0.113.44"), Port: 30001}, "postgres", "wrong") + + before := metricCounterValue(t, "duckgres_auth_failures_total") + _, err := h.sessionFromContext(ctx) + after := metricCounterValue(t, "duckgres_auth_failures_total") + + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("expected unauthenticated error, got %v", err) + } + if after-before != 1 { + t.Fatalf("expected duckgres_auth_failures_total delta 1, got %.0f", after-before) + } +} + +func TestSessionFromContextRateLimitedRejectsAndIncrementsMetric(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.45"), Port: 30002} + rateLimiter := server.NewRateLimiter(server.RateLimitConfig{ + MaxFailedAttempts: 1, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 100, + }) + rateLimiter.RecordFailedAuth(addr) + + h := testFlightHandlerWithStoreAndRateLimiter(t, map[string]string{"postgres": "postgres"}, rateLimiter) + ctx := authContextForPeer(addr, "postgres", "postgres") + + before := metricCounterValue(t, "duckgres_rate_limit_rejects_total") + _, err := h.sessionFromContext(ctx) + after := metricCounterValue(t, "duckgres_rate_limit_rejects_total") + + if status.Code(err) != codes.ResourceExhausted { + t.Fatalf("expected resource exhausted error, got %v", err) + } + if after-before != 1 { + t.Fatalf("expected duckgres_rate_limit_rejects_total delta 1, got %.0f", after-before) + } +} + +func TestSessionFromContextFailedAndSuccessfulAuthUpdateRateLimiter(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.46"), Port: 30003} + rateLimiter := server.NewRateLimiter(server.RateLimitConfig{ + MaxFailedAttempts: 2, + FailedAttemptWindow: time.Minute, + BanDuration: time.Hour, + MaxConnectionsPerIP: 100, + }) + h := testFlightHandlerWithStoreAndRateLimiter(t, map[string]string{"postgres": "postgres"}, rateLimiter) + + _, err := h.sessionFromContext(authContextForPeer(addr, "postgres", "wrong")) + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("expected unauthenticated error for bad password, got %v", err) + } + + s, err := h.sessionFromContext(authContextForPeer(addr, "postgres", "postgres")) + if err != nil { + t.Fatalf("expected successful auth, got %v", err) + } + if s == nil { + t.Fatalf("expected non-nil session") + } + + _, err = h.sessionFromContext(authContextForPeer(addr, "postgres", "wrong")) + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("expected unauthenticated error for bad password, got %v", err) + } + if rateLimiter.IsBanned(addr) { + t.Fatalf("expected successful auth to clear prior failures before next bad password") + } +} + func TestFlightAuthSessionStoreRetriesAfterForcedReapOnMaxWorkers(t *testing.T) { stale := newFlightClientSession(9001, "postgres", nil) stale.lastUsed.Store(time.Now().Add(-1 * time.Hour).UnixNano()) diff --git a/tests/controlplane/flight_ingress_test.go b/tests/controlplane/flight_ingress_test.go index 1148c6b..c87a05a 100644 --- a/tests/controlplane/flight_ingress_test.go +++ b/tests/controlplane/flight_ingress_test.go @@ -28,11 +28,6 @@ func flightAuthContext(username, password string) context.Context { return metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Basic "+token)) } -func basicAuthHeader(username, password string) string { - token := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) - return "Basic " + token -} - func newFlightClient(t *testing.T, port int) *flightsql.Client { t.Helper() addr := fmt.Sprintf("127.0.0.1:%d", port) @@ -150,12 +145,10 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) { errCh <- fmt.Errorf("worker %d bootstrap missing x-duckgres-session header", workerID) return } - authHeader := basicAuthHeader("testuser", "testpass") token := strings.TrimSpace(sessionTokens[0]) for i := 0; i < iterationsPerGoroutine; i++ { iterBaseCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs( - "authorization", authHeader, "x-duckgres-session", token, )) iterCtx, iterCancel := context.WithTimeout(iterBaseCtx, 20*time.Second) @@ -177,7 +170,7 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) { } } -func TestFlightIngressServerIssuedSessionTokenRequiresBasicAuth(t *testing.T) { +func TestFlightIngressServerIssuedSessionTokenAllowsTokenOnlyAuth(t *testing.T) { h := startControlPlane(t, cpOpts{ flightPort: freePort(t), maxWorkers: 1, @@ -206,8 +199,8 @@ func TestFlightIngressServerIssuedSessionTokenRequiresBasicAuth(t *testing.T) { ctx2, cancel2 := context.WithTimeout(tokenCtx, 20*time.Second) defer cancel2() - if _, err := client2.GetTables(ctx2, &flightsql.GetTablesOpts{}); err == nil { - t.Fatalf("expected token-only GetTables to fail without basic auth") + if _, err := client2.GetTables(ctx2, &flightsql.GetTablesOpts{}); err != nil { + t.Fatalf("expected token-only GetTables to succeed, got %v", err) } tokenAndAuthCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(