Skip to content

Commit 1ccd3aa

Browse files
committed
Align Flight auth behavior
1 parent 75a48fc commit 1ccd3aa

File tree

4 files changed

+85
-29
lines changed

4 files changed

+85
-29
lines changed

AGENTS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
- When following a plan file, mark tasks upon completion
99
- When creating new branch from origin/main, do not track origin/main.
1010
- Always run lint before committing.
11-
- Parallelize using subagents when possible.
11+
- Parallelize using subagents when possible.
12+
- Prefer correctness, maintanability, robustness over shortcut implementations

server/flightsqlingress/ingress.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,6 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) (
211211
remoteAddr = p.Addr
212212
}
213213

214-
releaseRateLimit, rejectReason := server.BeginRateLimitedAuthAttempt(h.rateLimiter, remoteAddr)
215-
defer releaseRateLimit()
216-
if rejectReason != "" {
217-
slog.Warn("Flight auth rejected by rate limit policy.", "remote_addr", remoteAddr, "reason", rejectReason)
218-
return nil, status.Error(codes.ResourceExhausted, "authentication rate limit exceeded")
219-
}
220-
221214
md, ok := metadata.FromIncomingContext(ctx)
222215
if !ok {
223216
server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr)
@@ -228,26 +221,44 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) (
228221
return nil, status.Error(codes.Unavailable, "session store is not configured")
229222
}
230223

231-
username, err := h.authenticateBasicCredentials(md, remoteAddr)
232-
if err != nil {
233-
return nil, err
234-
}
235-
236224
if sessionToken := incomingSessionToken(md); sessionToken != "" {
237225
s, ok := h.sessions.GetByToken(sessionToken)
238226
if !ok {
227+
server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr)
239228
return nil, status.Error(codes.Unauthenticated, "session not found")
240229
}
241230

242-
if username != s.username {
243-
return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user")
231+
// When Basic auth is included alongside a bearer session token, enforce
232+
// principal consistency. Token-only auth is allowed after bootstrap.
233+
if hasAuthorizationHeader(md) {
234+
username, err := h.authenticateBasicCredentials(md, remoteAddr)
235+
if err != nil {
236+
return nil, err
237+
}
238+
if username != s.username {
239+
server.RecordFailedAuthAttempt(h.rateLimiter, remoteAddr)
240+
return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user")
241+
}
244242
}
245243

246244
setSessionTokenMetadata(ctx, sessionToken)
247245
s.touch()
248246
return s, nil
249247
}
250248

249+
// Bootstrap requires Basic auth and is subject to auth rate limiting.
250+
releaseRateLimit, rejectReason := server.BeginRateLimitedAuthAttempt(h.rateLimiter, remoteAddr)
251+
defer releaseRateLimit()
252+
if rejectReason != "" {
253+
slog.Warn("Flight auth rejected by rate limit policy.", "remote_addr", remoteAddr, "reason", rejectReason)
254+
return nil, status.Error(codes.ResourceExhausted, "authentication rate limit exceeded")
255+
}
256+
257+
username, err := h.authenticateBasicCredentials(md, remoteAddr)
258+
if err != nil {
259+
return nil, err
260+
}
261+
251262
s, err := h.sessions.Create(ctx, username)
252263
if err != nil {
253264
return nil, status.Errorf(codes.Unavailable, "create bootstrap session: %v", err)
@@ -258,6 +269,10 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) (
258269
return s, nil
259270
}
260271

272+
func hasAuthorizationHeader(md metadata.MD) bool {
273+
return len(md.Get("authorization")) > 0
274+
}
275+
261276
func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(md metadata.MD, remoteAddr net.Addr) (string, error) {
262277
authHeaders := md.Get("authorization")
263278
if len(authHeaders) == 0 {

server/flightsqlingress/ingress_test.go

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,12 @@ func TestFlightAuthSessionKeyDoesNotTrustMetadataClientOverride(t *testing.T) {
172172
}
173173
}
174174

175-
func TestSessionFromContextRejectsServerIssuedSessionTokenWithoutBasicAuth(t *testing.T) {
175+
func TestSessionFromContextAcceptsServerIssuedSessionTokenWithoutBasicAuth(t *testing.T) {
176+
s := newFlightClientSession(1234, "postgres", nil)
177+
s.token = "issued-token"
176178
store := &flightAuthSessionStore{
177179
sessions: map[string]*flightClientSession{
178-
"issued-token": newFlightClientSession(1234, "postgres", nil),
180+
"issued-token": s,
179181
},
180182
}
181183

@@ -185,8 +187,15 @@ func TestSessionFromContextRejectsServerIssuedSessionTokenWithoutBasicAuth(t *te
185187
}
186188

187189
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", "issued-token"))
188-
if _, err := h.sessionFromContext(ctx); err == nil {
189-
t.Fatalf("expected token-only auth to be rejected")
190+
got, err := h.sessionFromContext(ctx)
191+
if err != nil {
192+
t.Fatalf("expected token-only auth to succeed, got %v", err)
193+
}
194+
if got == nil {
195+
t.Fatalf("expected non-nil session")
196+
}
197+
if got != s {
198+
t.Fatalf("expected existing token session to be reused")
190199
}
191200
}
192201

@@ -253,6 +262,44 @@ func TestSessionFromContextAcceptsServerIssuedSessionTokenWithBasicAuth(t *testi
253262
}
254263
}
255264

265+
func TestSessionFromContextTokenPathDoesNotClearRateLimiterFailures(t *testing.T) {
266+
addr := &net.TCPAddr{IP: net.ParseIP("203.0.113.47"), Port: 30004}
267+
rateLimiter := server.NewRateLimiter(server.RateLimitConfig{
268+
MaxFailedAttempts: 2,
269+
FailedAttemptWindow: time.Minute,
270+
BanDuration: time.Hour,
271+
MaxConnectionsPerIP: 100,
272+
})
273+
rateLimiter.RecordFailedAuth(addr)
274+
275+
s := newFlightClientSession(1234, "postgres", nil)
276+
s.token = "issued-token"
277+
store := &flightAuthSessionStore{
278+
sessions: map[string]*flightClientSession{
279+
"issued-token": s,
280+
},
281+
}
282+
h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"})
283+
if err != nil {
284+
t.Fatalf("failed to construct handler: %v", err)
285+
}
286+
h.rateLimiter = rateLimiter
287+
288+
base := peer.NewContext(context.Background(), &peer.Peer{Addr: addr})
289+
ctx := metadata.NewIncomingContext(base, metadata.Pairs("x-duckgres-session", "issued-token"))
290+
if _, err := h.sessionFromContext(ctx); err != nil {
291+
t.Fatalf("token-only auth failed: %v", err)
292+
}
293+
294+
_, err = h.sessionFromContext(authContextForPeer(addr, "postgres", "wrong"))
295+
if status.Code(err) != codes.Unauthenticated {
296+
t.Fatalf("expected unauthenticated error for bad password, got %v", err)
297+
}
298+
if !rateLimiter.IsBanned(addr) {
299+
t.Fatalf("expected prior failure + new failure to ban; token-only path should not clear failures")
300+
}
301+
}
302+
256303
func TestSessionFromContextWithoutTokenCreatesDistinctSessions(t *testing.T) {
257304
var createCalls atomic.Int32
258305
store := &flightAuthSessionStore{

tests/controlplane/flight_ingress_test.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@ func flightAuthContext(username, password string) context.Context {
2828
return metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Basic "+token))
2929
}
3030

31-
func basicAuthHeader(username, password string) string {
32-
token := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
33-
return "Basic " + token
34-
}
35-
3631
func newFlightClient(t *testing.T, port int) *flightsql.Client {
3732
t.Helper()
3833
addr := fmt.Sprintf("127.0.0.1:%d", port)
@@ -150,12 +145,10 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) {
150145
errCh <- fmt.Errorf("worker %d bootstrap missing x-duckgres-session header", workerID)
151146
return
152147
}
153-
authHeader := basicAuthHeader("testuser", "testpass")
154148
token := strings.TrimSpace(sessionTokens[0])
155149

156150
for i := 0; i < iterationsPerGoroutine; i++ {
157151
iterBaseCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(
158-
"authorization", authHeader,
159152
"x-duckgres-session", token,
160153
))
161154
iterCtx, iterCancel := context.WithTimeout(iterBaseCtx, 20*time.Second)
@@ -177,7 +170,7 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) {
177170
}
178171
}
179172

180-
func TestFlightIngressServerIssuedSessionTokenRequiresBasicAuth(t *testing.T) {
173+
func TestFlightIngressServerIssuedSessionTokenAllowsTokenOnlyAuth(t *testing.T) {
181174
h := startControlPlane(t, cpOpts{
182175
flightPort: freePort(t),
183176
maxWorkers: 1,
@@ -206,8 +199,8 @@ func TestFlightIngressServerIssuedSessionTokenRequiresBasicAuth(t *testing.T) {
206199
ctx2, cancel2 := context.WithTimeout(tokenCtx, 20*time.Second)
207200
defer cancel2()
208201

209-
if _, err := client2.GetTables(ctx2, &flightsql.GetTablesOpts{}); err == nil {
210-
t.Fatalf("expected token-only GetTables to fail without basic auth")
202+
if _, err := client2.GetTables(ctx2, &flightsql.GetTablesOpts{}); err != nil {
203+
t.Fatalf("expected token-only GetTables to succeed, got %v", err)
211204
}
212205

213206
tokenAndAuthCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(

0 commit comments

Comments
 (0)