diff --git a/README.md b/README.md index b9d36a3..20246f9 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ flight_port: 8815 flight_session_idle_ttl: "10m" flight_session_reap_interval: "1m" flight_handle_idle_ttl: "15m" +flight_session_token_ttl: "1h" data_dir: "./data" tls: @@ -176,6 +177,7 @@ Run with config file: | `DUCKGRES_FLIGHT_SESSION_IDLE_TTL` | Flight auth session idle TTL | `10m` | | `DUCKGRES_FLIGHT_SESSION_REAP_INTERVAL` | Flight auth session reap interval | `1m` | | `DUCKGRES_FLIGHT_HANDLE_IDLE_TTL` | Flight prepared/query handle idle TTL | `15m` | +| `DUCKGRES_FLIGHT_SESSION_TOKEN_TTL` | Flight issued session token absolute TTL | `1h` | | `DUCKGRES_DATA_DIR` | Directory for DuckDB files | `./data` | | `DUCKGRES_CERT` | TLS certificate file | `./certs/server.crt` | | `DUCKGRES_KEY` | TLS private key file | `./certs/server.key` | @@ -221,6 +223,7 @@ Options: -flight-session-idle-ttl string Flight auth session idle TTL (e.g., '10m') -flight-session-reap-interval string Flight auth session reap interval (e.g., '1m') -flight-handle-idle-ttl string Flight prepared/query handle idle TTL (e.g., '15m') + -flight-session-token-ttl string Flight issued session token absolute TTL (e.g., '1h') -data-dir string Directory for DuckDB files -cert string TLS certificate file -key string TLS private key file diff --git a/config_resolution.go b/config_resolution.go index 5d4d474..20ec43c 100644 --- a/config_resolution.go +++ b/config_resolution.go @@ -16,6 +16,7 @@ type configCLIInputs struct { FlightSessionIdleTTL string FlightSessionReapInterval string FlightHandleIdleTTL string + FlightSessionTokenTTL string DataDir string CertFile string KeyFile string @@ -44,6 +45,7 @@ func defaultServerConfig() server.Config { FlightSessionIdleTTL: 10 * time.Minute, FlightSessionReapInterval: 1 * time.Minute, FlightHandleIdleTTL: 15 * time.Minute, + FlightSessionTokenTTL: 1 * time.Hour, DataDir: "./data", TLSCertFile: "./certs/server.crt", TLSKeyFile: "./certs/server.key", @@ -98,6 +100,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun warn("Invalid flight_handle_idle_ttl duration: " + err.Error()) } } + if fileCfg.FlightSessionTokenTTL != "" { + if d, err := time.ParseDuration(fileCfg.FlightSessionTokenTTL); err == nil { + cfg.FlightSessionTokenTTL = d + } else { + warn("Invalid flight_session_token_ttl duration: " + err.Error()) + } + } if fileCfg.DataDir != "" { cfg.DataDir = fileCfg.DataDir } @@ -251,6 +260,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun warn("Invalid DUCKGRES_FLIGHT_HANDLE_IDLE_TTL duration: " + err.Error()) } } + if v := getenv("DUCKGRES_FLIGHT_SESSION_TOKEN_TTL"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.FlightSessionTokenTTL = d + } else { + warn("Invalid DUCKGRES_FLIGHT_SESSION_TOKEN_TTL duration: " + err.Error()) + } + } if v := getenv("DUCKGRES_DATA_DIR"); v != "" { cfg.DataDir = v } @@ -385,6 +401,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun warn("Invalid --flight-handle-idle-ttl duration: " + err.Error()) } } + if cli.Set["flight-session-token-ttl"] { + if d, err := time.ParseDuration(cli.FlightSessionTokenTTL); err == nil { + cfg.FlightSessionTokenTTL = d + } else { + warn("Invalid --flight-session-token-ttl duration: " + err.Error()) + } + } if cli.Set["data-dir"] { cfg.DataDir = cli.DataDir } diff --git a/controlplane/control.go b/controlplane/control.go index df29625..6e41464 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -174,6 +174,7 @@ func RunControlPlane(cfg ControlPlaneConfig) { SessionIdleTTL: cfg.FlightSessionIdleTTL, SessionReapTick: cfg.FlightSessionReapInterval, HandleIdleTTL: cfg.FlightHandleIdleTTL, + SessionTokenTTL: cfg.FlightSessionTokenTTL, }) if err != nil { slog.Error("Failed to initialize Flight ingress.", "error", err) diff --git a/duckgres.example.yaml b/duckgres.example.yaml index f6b81b7..96b53f5 100644 --- a/duckgres.example.yaml +++ b/duckgres.example.yaml @@ -8,6 +8,10 @@ port: 5432 # Control-plane Arrow Flight SQL ingress (optional) # 0 or omitted disables Flight ingress. # flight_port: 8815 +# flight_session_idle_ttl: "10m" +# flight_session_reap_interval: "1m" +# flight_handle_idle_ttl: "15m" +# flight_session_token_ttl: "1h" # Directory for DuckDB database files (one per user) data_dir: "./data" diff --git a/main.go b/main.go index ccba543..3aecdc8 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,7 @@ type FileConfig struct { FlightSessionIdleTTL string `yaml:"flight_session_idle_ttl"` // e.g., "10m" FlightSessionReapInterval string `yaml:"flight_session_reap_interval"` // e.g., "1m" FlightHandleIdleTTL string `yaml:"flight_handle_idle_ttl"` // e.g., "15m" + FlightSessionTokenTTL string `yaml:"flight_session_token_ttl"` // e.g., "1h" DataDir string `yaml:"data_dir"` TLS TLSConfig `yaml:"tls"` Users map[string]string `yaml:"users"` @@ -145,6 +146,7 @@ func main() { flightSessionIdleTTL := flag.String("flight-session-idle-ttl", "", "Flight auth session idle TTL (e.g., '10m') (env: DUCKGRES_FLIGHT_SESSION_IDLE_TTL)") flightSessionReapInterval := flag.String("flight-session-reap-interval", "", "Flight auth session reap interval (e.g., '1m') (env: DUCKGRES_FLIGHT_SESSION_REAP_INTERVAL)") flightHandleIdleTTL := flag.String("flight-handle-idle-ttl", "", "Flight prepared/query handle idle TTL (e.g., '15m') (env: DUCKGRES_FLIGHT_HANDLE_IDLE_TTL)") + flightSessionTokenTTL := flag.String("flight-session-token-ttl", "", "Flight issued session token absolute TTL (e.g., '1h') (env: DUCKGRES_FLIGHT_SESSION_TOKEN_TTL)") dataDir := flag.String("data-dir", "", "Directory for DuckDB files (env: DUCKGRES_DATA_DIR)") certFile := flag.String("cert", "", "TLS certificate file (env: DUCKGRES_CERT)") keyFile := flag.String("key", "", "TLS private key file (env: DUCKGRES_KEY)") @@ -189,6 +191,7 @@ func main() { fmt.Fprintf(os.Stderr, " DUCKGRES_FLIGHT_SESSION_IDLE_TTL Flight auth session idle TTL (default: 10m)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_FLIGHT_SESSION_REAP_INTERVAL Flight auth session reap interval (default: 1m)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_FLIGHT_HANDLE_IDLE_TTL Flight prepared/query handle idle TTL (default: 15m)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_FLIGHT_SESSION_TOKEN_TTL Flight issued session token absolute TTL (default: 1h)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") @@ -271,6 +274,7 @@ func main() { FlightSessionIdleTTL: *flightSessionIdleTTL, FlightSessionReapInterval: *flightSessionReapInterval, FlightHandleIdleTTL: *flightHandleIdleTTL, + FlightSessionTokenTTL: *flightSessionTokenTTL, DataDir: *dataDir, CertFile: *certFile, KeyFile: *keyFile, diff --git a/main_test.go b/main_test.go index 6e9cd49..642f7ab 100644 --- a/main_test.go +++ b/main_test.go @@ -375,12 +375,14 @@ func TestResolveEffectiveConfigFlightIngressDurations(t *testing.T) { FlightSessionIdleTTL: "7m", FlightSessionReapInterval: "45s", FlightHandleIdleTTL: "3m", + FlightSessionTokenTTL: "2h", } env := map[string]string{ "DUCKGRES_FLIGHT_SESSION_IDLE_TTL": "9m", "DUCKGRES_FLIGHT_SESSION_REAP_INTERVAL": "30s", "DUCKGRES_FLIGHT_HANDLE_IDLE_TTL": "4m", + "DUCKGRES_FLIGHT_SESSION_TOKEN_TTL": "90m", } resolved := resolveEffectiveConfig(fileCfg, configCLIInputs{ @@ -388,10 +390,12 @@ func TestResolveEffectiveConfigFlightIngressDurations(t *testing.T) { "flight-session-idle-ttl": true, "flight-session-reap-interval": true, "flight-handle-idle-ttl": true, + "flight-session-token-ttl": true, }, FlightSessionIdleTTL: "11m", FlightSessionReapInterval: "15s", FlightHandleIdleTTL: "5m", + FlightSessionTokenTTL: "75m", }, envFromMap(env), nil) if resolved.Server.FlightSessionIdleTTL != 11*time.Minute { @@ -403,6 +407,9 @@ func TestResolveEffectiveConfigFlightIngressDurations(t *testing.T) { if resolved.Server.FlightHandleIdleTTL != 5*time.Minute { t.Fatalf("expected CLI flight_handle_idle_ttl, got %s", resolved.Server.FlightHandleIdleTTL) } + if resolved.Server.FlightSessionTokenTTL != 75*time.Minute { + t.Fatalf("expected CLI flight_session_token_ttl, got %s", resolved.Server.FlightSessionTokenTTL) + } } func TestResolveEffectiveConfigFlightIngressDurationsFromFile(t *testing.T) { @@ -410,6 +417,7 @@ func TestResolveEffectiveConfigFlightIngressDurationsFromFile(t *testing.T) { FlightSessionIdleTTL: "7m", FlightSessionReapInterval: "45s", FlightHandleIdleTTL: "3m", + FlightSessionTokenTTL: "2h", } resolved := resolveEffectiveConfig(fileCfg, configCLIInputs{}, envFromMap(nil), nil) @@ -423,6 +431,9 @@ func TestResolveEffectiveConfigFlightIngressDurationsFromFile(t *testing.T) { if resolved.Server.FlightHandleIdleTTL != 3*time.Minute { t.Fatalf("expected file flight_handle_idle_ttl, got %s", resolved.Server.FlightHandleIdleTTL) } + if resolved.Server.FlightSessionTokenTTL != 2*time.Hour { + t.Fatalf("expected file flight_session_token_ttl, got %s", resolved.Server.FlightSessionTokenTTL) + } } func TestResolveEffectiveConfigFlightIngressDurationsFromEnv(t *testing.T) { @@ -430,6 +441,7 @@ func TestResolveEffectiveConfigFlightIngressDurationsFromEnv(t *testing.T) { "DUCKGRES_FLIGHT_SESSION_IDLE_TTL": "9m", "DUCKGRES_FLIGHT_SESSION_REAP_INTERVAL": "30s", "DUCKGRES_FLIGHT_HANDLE_IDLE_TTL": "4m", + "DUCKGRES_FLIGHT_SESSION_TOKEN_TTL": "30m", } resolved := resolveEffectiveConfig(nil, configCLIInputs{}, envFromMap(env), nil) @@ -443,6 +455,9 @@ func TestResolveEffectiveConfigFlightIngressDurationsFromEnv(t *testing.T) { if resolved.Server.FlightHandleIdleTTL != 4*time.Minute { t.Fatalf("expected env flight_handle_idle_ttl, got %s", resolved.Server.FlightHandleIdleTTL) } + if resolved.Server.FlightSessionTokenTTL != 30*time.Minute { + t.Fatalf("expected env flight_session_token_ttl, got %s", resolved.Server.FlightSessionTokenTTL) + } } func TestResolveEffectiveConfigInvalidFlightPortEnv(t *testing.T) { diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index 7fade6e..a8fd320 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -2,10 +2,12 @@ package flightsqlingress import ( "context" + "crypto/rand" "crypto/subtle" "crypto/tls" "database/sql" "encoding/base64" + "encoding/hex" "fmt" "log/slog" "net" @@ -30,11 +32,12 @@ import ( ) const ( - flightBatchSize = 1024 - defaultFlightSessionIdleTTL = 10 * time.Minute - defaultFlightSessionReapTick = 1 * time.Minute - defaultFlightHandleIdleTTL = 15 * time.Minute - defaultFlightClientIDHeaderKey = "x-duckgres-client-id" + flightBatchSize = 1024 + defaultFlightSessionIdleTTL = 10 * time.Minute + defaultFlightSessionReapTick = 1 * time.Minute + defaultFlightHandleIdleTTL = 15 * time.Minute + defaultFlightSessionTokenTTL = 1 * time.Hour + defaultFlightSessionHeaderKey = "x-duckgres-session" ) const ( @@ -50,6 +53,7 @@ type Config struct { SessionIdleTTL time.Duration SessionReapTick time.Duration HandleIdleTTL time.Duration + SessionTokenTTL time.Duration } type SessionProvider interface { @@ -109,8 +113,11 @@ func NewFlightIngress(host string, port int, tlsConfig *tls.Config, users map[st if cfg.HandleIdleTTL <= 0 { cfg.HandleIdleTTL = defaultFlightHandleIdleTTL } + if cfg.SessionTokenTTL <= 0 { + cfg.SessionTokenTTL = defaultFlightSessionTokenTTL + } - store := newFlightAuthSessionStore(provider, cfg.SessionIdleTTL, cfg.SessionReapTick, cfg.HandleIdleTTL, opts) + store := newFlightAuthSessionStore(provider, cfg.SessionIdleTTL, cfg.SessionReapTick, cfg.HandleIdleTTL, cfg.SessionTokenTTL, opts) handler, err := NewControlPlaneFlightSQLHandler(store, users) if err != nil { _ = ln.Close() @@ -202,14 +209,49 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) ( return nil, status.Error(codes.Unauthenticated, "missing metadata") } + if h.sessions == nil { + return nil, status.Error(codes.Unavailable, "session store is not configured") + } + + username, err := h.authenticateBasicCredentials(md) + if err != nil { + return nil, err + } + + if sessionToken := incomingSessionToken(md); sessionToken != "" { + s, ok := h.sessions.GetByToken(sessionToken) + if !ok { + return nil, status.Error(codes.Unauthenticated, "session not found") + } + + if username != s.username { + return nil, status.Error(codes.PermissionDenied, "session token does not match authenticated user") + } + + setSessionTokenMetadata(ctx, sessionToken) + s.touch() + return s, nil + } + + s, err := h.sessions.Create(ctx, username) + if err != nil { + return nil, status.Errorf(codes.Unavailable, "create bootstrap session: %v", err) + } + + setSessionTokenMetadata(ctx, s.token) + s.touch() + return s, nil +} + +func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(md metadata.MD) (string, error) { authHeaders := md.Get("authorization") if len(authHeaders) == 0 { - return nil, status.Error(codes.Unauthenticated, "missing authorization header") + return "", status.Error(codes.Unauthenticated, "missing authorization header") } username, password, err := parseBasicCredentials(authHeaders[0]) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return "", status.Error(codes.Unauthenticated, err.Error()) } expected, userFound := h.users[username] @@ -218,16 +260,26 @@ func (h *ControlPlaneFlightSQLHandler) sessionFromContext(ctx context.Context) ( } passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expected)) == 1 if !userFound || !passMatch { - return nil, status.Error(codes.Unauthenticated, "invalid credentials") + return "", status.Error(codes.Unauthenticated, "invalid credentials") } + return username, nil +} - sessionKey := flightAuthSessionKey(ctx, username) - s, err := h.sessions.GetOrCreate(ctx, sessionKey, username) - if err != nil { - return nil, status.Errorf(codes.Unavailable, "create session: %v", err) +func incomingSessionToken(md metadata.MD) string { + values := md.Get(defaultFlightSessionHeaderKey) + if len(values) == 0 { + return "" } - s.touch() - return s, nil + return strings.TrimSpace(values[0]) +} + +func setSessionTokenMetadata(ctx context.Context, sessionToken string) { + if sessionToken == "" { + return + } + md := metadata.Pairs(defaultFlightSessionHeaderKey, sessionToken) + _ = grpc.SetHeader(ctx, md) + _ = grpc.SetTrailer(ctx, md) } func (h *ControlPlaneFlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -334,9 +386,13 @@ func (h *ControlPlaneFlightSQLHandler) DoPutCommandStatementUpdate(ctx context.C return 0, status.Errorf(codes.InvalidArgument, "failed to execute update: %v", err) } + return rowsAffectedOrError(res) +} + +func rowsAffectedOrError(res server.ExecResult) (int64, error) { affected, err := res.RowsAffected() if err != nil { - return 0, nil + return 0, status.Errorf(codes.Internal, "failed to fetch affected row count: %v", err) } return affected, nil } @@ -789,12 +845,15 @@ type flightQueryHandle struct { type flightClientSession struct { pid int32 + token string username string executor *server.FlightExecutor lastUsed atomic.Int64 - counter atomic.Uint64 - streams atomic.Int32 + // tokenIssuedAt stores when this token was issued; used for absolute token TTL. + tokenIssuedAt atomic.Int64 + counter atomic.Uint64 + streams atomic.Int32 opMu sync.Mutex @@ -942,6 +1001,7 @@ type flightAuthSessionStore struct { idleTTL time.Duration reapInterval time.Duration handleIdleTTL time.Duration + tokenTTL time.Duration hooks Hooks isMaxWorkerFn func(error) bool @@ -949,7 +1009,8 @@ type flightAuthSessionStore struct { destroySessionFn func(int32) mu sync.RWMutex - sessions map[string]*flightClientSession + sessions map[string]*flightClientSession // session token -> session + byKey map[string]string // auth bootstrap key -> session token stopOnce sync.Once stopCh chan struct{} @@ -968,7 +1029,7 @@ func (r *lockedRowSet) Close() error { return err } -func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, handleIdleTTL time.Duration, opts Options) *flightAuthSessionStore { +func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, handleIdleTTL, tokenTTL time.Duration, opts Options) *flightAuthSessionStore { createFn := func(context.Context, string) (int32, *server.FlightExecutor, error) { return 0, nil, fmt.Errorf("session provider is not configured") } @@ -987,11 +1048,13 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, idleTTL: idleTTL, reapInterval: reapInterval, handleIdleTTL: handleIdleTTL, + tokenTTL: tokenTTL, hooks: opts.Hooks, isMaxWorkerFn: isMaxWorkerFn, createSessionFn: createFn, destroySessionFn: destroyFn, sessions: make(map[string]*flightClientSession), + byKey: make(map[string]string), stopCh: make(chan struct{}), doneCh: make(chan struct{}), } @@ -999,6 +1062,14 @@ func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, return s } +func (s *flightAuthSessionStore) Create(ctx context.Context, username string) (*flightClientSession, error) { + bootstrapNonce, err := generateSessionIdentityToken() + if err != nil { + return nil, fmt.Errorf("generate bootstrap nonce: %w", err) + } + return s.GetOrCreate(ctx, "bootstrap|"+username+"|"+bootstrapNonce, username) +} + func (s *flightAuthSessionStore) notifySessionCountChanged(count int) { if s.hooks.OnSessionCountChanged != nil { s.hooks.OnSessionCountChanged(count) @@ -1021,7 +1092,7 @@ func (s *flightAuthSessionStore) notifyMaxWorkerRetry(outcome string) { } func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username string) (*flightClientSession, error) { - existing, ok := s.getExisting(key) + existing, ok := s.getExistingByKey(key) if ok { existing.touch() return existing, nil @@ -1051,16 +1122,38 @@ func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username if err != nil { return nil, err } + + token, tokenErr := generateSessionIdentityToken() + if tokenErr != nil { + s.destroySessionFn(pid) + return nil, fmt.Errorf("generate session identity token: %w", tokenErr) + } created := newFlightClientSession(pid, username, executor) + created.token = token s.mu.Lock() - if existing, ok := s.sessions[key]; ok { + s.ensureMapsLocked() + if existing, ok := s.getExistingByKeyLocked(key); ok { s.mu.Unlock() s.destroySessionFn(pid) existing.touch() return existing, nil } - s.sessions[key] = created + for { + if _, exists := s.sessions[created.token]; !exists { + break + } + token, tokenErr = generateSessionIdentityToken() + if tokenErr != nil { + s.mu.Unlock() + s.destroySessionFn(pid) + return nil, fmt.Errorf("generate session identity token: %w", tokenErr) + } + created.token = token + } + s.sessions[created.token] = created + created.tokenIssuedAt.Store(time.Now().UnixNano()) + s.byKey[key] = created.token sessionCount := len(s.sessions) s.mu.Unlock() s.notifySessionCountChanged(sessionCount) @@ -1068,20 +1161,102 @@ func (s *flightAuthSessionStore) GetOrCreate(ctx context.Context, key, username return created, nil } -func (s *flightAuthSessionStore) getExisting(key string) (*flightClientSession, bool) { +func (s *flightAuthSessionStore) GetByToken(token string) (*flightClientSession, bool) { + token = strings.TrimSpace(token) + if token == "" { + return nil, false + } + + var ( + session *flightClientSession + ok bool + expiredSession *flightClientSession + postExpireCount int + tokenIssuedAtRaw int64 + ) + + s.mu.Lock() + s.ensureMapsLocked() + session, ok = s.sessions[token] + if !ok { + s.mu.Unlock() + return nil, false + } + + tokenIssuedAtRaw = session.tokenIssuedAt.Load() + if s.tokenTTL > 0 && tokenIssuedAtRaw > 0 { + tokenAge := time.Since(time.Unix(0, tokenIssuedAtRaw)) + if tokenAge >= s.tokenTTL { + delete(s.sessions, token) + s.removeByKeyForTokenLocked(token) + expiredSession = session + postExpireCount = len(s.sessions) + destroyFn := s.destroySessionFn + s.mu.Unlock() + + if destroyFn != nil { + destroyFn(expiredSession.pid) + } + s.notifySessionCountChanged(postExpireCount) + return nil, false + } + } + s.mu.Unlock() + return session, true +} + +func (s *flightAuthSessionStore) getExistingByKey(key string) (*flightClientSession, bool) { s.mu.RLock() - existing, ok := s.sessions[key] + existing, ok := s.getExistingByKeyLocked(key) s.mu.RUnlock() + if !ok { + s.pruneStaleByKey(key) + } + return existing, ok +} + +func (s *flightAuthSessionStore) getExistingByKeyLocked(key string) (*flightClientSession, bool) { + token, ok := s.byKey[key] + if !ok { + return nil, false + } + existing, ok := s.sessions[token] return existing, ok } +// pruneStaleByKey removes a bootstrap key only if it still points to a missing +// session token. This method must hold an exclusive lock because it mutates maps. +func (s *flightAuthSessionStore) pruneStaleByKey(key string) { + s.mu.Lock() + defer s.mu.Unlock() + s.ensureMapsLocked() + + token, ok := s.byKey[key] + if !ok { + return + } + if _, exists := s.sessions[token]; exists { + return + } + delete(s.byKey, key) +} + +func (s *flightAuthSessionStore) ensureMapsLocked() { + if s.sessions == nil { + s.sessions = make(map[string]*flightClientSession) + } + if s.byKey == nil { + s.byKey = make(map[string]string) + } +} + func (s *flightAuthSessionStore) waitForExisting(ctx context.Context, key string, timeout time.Duration) (*flightClientSession, bool) { deadline := time.Now().Add(timeout) ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for time.Now().Before(deadline) { - if existing, ok := s.getExisting(key); ok { + if existing, ok := s.getExistingByKey(key); ok { return existing, true } select { @@ -1090,7 +1265,7 @@ func (s *flightAuthSessionStore) waitForExisting(ctx context.Context, key string case <-ticker.C: } } - return s.getExisting(key) + return s.getExistingByKey(key) } func (s *flightAuthSessionStore) Close() { @@ -1099,11 +1274,13 @@ func (s *flightAuthSessionStore) Close() { <-s.doneCh s.mu.Lock() + s.ensureMapsLocked() sessions := make([]*flightClientSession, 0, len(s.sessions)) for _, cs := range s.sessions { sessions = append(sessions, cs) } s.sessions = make(map[string]*flightClientSession) + s.byKey = make(map[string]string) s.mu.Unlock() s.notifySessionCountChanged(0) @@ -1122,7 +1299,8 @@ func (s *flightAuthSessionStore) reapIdle(now time.Time, trigger string) int { sessionCount := 0 s.mu.Lock() - for key, cs := range s.sessions { + s.ensureMapsLocked() + for token, cs := range s.sessions { cs.reapStaleHandles(now, s.handleIdleTTL) last := time.Unix(0, cs.lastUsed.Load()) @@ -1138,7 +1316,8 @@ func (s *flightAuthSessionStore) reapIdle(now time.Time, trigger string) int { if cs.queryCount() > 0 { continue } - delete(s.sessions, key) + delete(s.sessions, token) + s.removeByKeyForTokenLocked(token) stale = append(stale, cs) } sessionCount = len(s.sessions) @@ -1155,6 +1334,14 @@ func (s *flightAuthSessionStore) reapIdle(now time.Time, trigger string) int { return reaped } +func (s *flightAuthSessionStore) removeByKeyForTokenLocked(token string) { + for key, mappedToken := range s.byKey { + if mappedToken == token { + delete(s.byKey, key) + } + } +} + func (s *flightAuthSessionStore) reapLoop() { ticker := time.NewTicker(s.reapInterval) defer ticker.Stop() @@ -1201,26 +1388,26 @@ func parseBasicCredentials(authHeader string) (username, password string, err er } func flightAuthSessionKey(ctx context.Context, username string) string { - clientID := "" - if md, ok := metadata.FromIncomingContext(ctx); ok { - if values := md.Get(defaultFlightClientIDHeaderKey); len(values) > 0 { - clientID = strings.TrimSpace(values[0]) - } - } - if clientID == "" { - clientID = "unknown" - if p, ok := peer.FromContext(ctx); ok && p != nil && p.Addr != nil { - host, _, err := net.SplitHostPort(p.Addr.String()) - if err == nil && host != "" { - clientID = host - } else { - clientID = p.Addr.String() - } + clientID := "unknown" + if p, ok := peer.FromContext(ctx); ok && p != nil && p.Addr != nil { + host, _, err := net.SplitHostPort(p.Addr.String()) + if err == nil && host != "" { + clientID = host + } else { + clientID = p.Addr.String() } } return username + "|" + clientID } +func generateSessionIdentityToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + func getQuerySchema(ctx context.Context, session *flightClientSession, query string) (*arrow.Schema, error) { q := strings.TrimRight(strings.TrimSpace(query), ";") upper := strings.ToUpper(q) diff --git a/server/flightsqlingress/ingress_test.go b/server/flightsqlingress/ingress_test.go index 330006e..42aac86 100644 --- a/server/flightsqlingress/ingress_test.go +++ b/server/flightsqlingress/ingress_test.go @@ -19,6 +19,15 @@ import ( var errTestMaxWorkersReached = errors.New("max workers reached") +type testExecResult struct { + affected int64 + err error +} + +func (r testExecResult) RowsAffected() (int64, error) { + return r.affected, r.err +} + func TestParseBasicCredentials(t *testing.T) { token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) user, pass, err := parseBasicCredentials("Basic " + token) @@ -57,6 +66,23 @@ func TestSupportsLimit(t *testing.T) { } } +func TestRowsAffectedOrErrorPropagatesRowsAffectedError(t *testing.T) { + _, err := rowsAffectedOrError(testExecResult{err: errors.New("not available")}) + if err == nil { + t.Fatalf("expected rowsAffectedOrError to return an error") + } +} + +func TestRowsAffectedOrErrorReturnsAffectedCount(t *testing.T) { + affected, err := rowsAffectedOrError(testExecResult{affected: 42}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if affected != 42 { + t.Fatalf("expected affected=42, got %d", affected) + } +} + func TestFlightAuthSessionKeyStableAcrossPeerPorts(t *testing.T) { ctx1 := peer.NewContext(context.Background(), &peer.Peer{ Addr: &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 40000}, @@ -76,18 +102,246 @@ func TestFlightAuthSessionKeyStableAcrossPeerPorts(t *testing.T) { } } -func TestFlightAuthSessionKeyMetadataClientOverride(t *testing.T) { +func TestFlightAuthSessionKeyDoesNotTrustMetadataClientOverride(t *testing.T) { base := peer.NewContext(context.Background(), &peer.Peer{ Addr: &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 45555}, }) ctx := metadata.NewIncomingContext(base, metadata.Pairs("x-duckgres-client-id", "worker-a")) key := flightAuthSessionKey(ctx, "postgres") - if !strings.Contains(key, "worker-a") { - t.Fatalf("expected session key to include metadata client id, got %q", key) + if strings.Contains(key, "worker-a") { + t.Fatalf("session key should ignore untrusted metadata client id: %q", key) } if strings.Contains(key, "45555") { - t.Fatalf("session key should ignore peer source port when client id is provided: %q", key) + t.Fatalf("session key should not include peer source port: %q", key) + } +} + +func TestSessionFromContextRejectsServerIssuedSessionTokenWithoutBasicAuth(t *testing.T) { + store := &flightAuthSessionStore{ + sessions: map[string]*flightClientSession{ + "issued-token": newFlightClientSession(1234, "postgres", nil), + }, + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", "issued-token")) + if _, err := h.sessionFromContext(ctx); err == nil { + t.Fatalf("expected token-only auth to be rejected") + } +} + +func TestSessionFromContextRejectsUnknownSessionTokenEvenWithBasicAuth(t *testing.T) { + store := &flightAuthSessionStore{ + createSessionFn: func(context.Context, string) (int32, *server.FlightExecutor, error) { + return 9876, nil, nil + }, + destroySessionFn: func(int32) {}, + sessions: make(map[string]*flightClientSession), + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs( + "x-duckgres-session", "missing-token", + "authorization", "Basic "+token, + ), + ) + + if _, err := h.sessionFromContext(ctx); err == nil { + t.Fatalf("expected unknown session token to be rejected") + } +} + +func TestSessionFromContextAcceptsServerIssuedSessionTokenWithBasicAuth(t *testing.T) { + s := newFlightClientSession(1234, "postgres", nil) + s.token = "issued-token" + store := &flightAuthSessionStore{ + sessions: map[string]*flightClientSession{ + "issued-token": s, + }, + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs( + "x-duckgres-session", "issued-token", + "authorization", "Basic "+token, + ), + ) + + got, err := h.sessionFromContext(ctx) + if err != nil { + t.Fatalf("expected token+basic auth to succeed, got error: %v", err) + } + if got == nil { + t.Fatalf("expected non-nil session") + } + if got.username != "postgres" { + t.Fatalf("expected postgres session, got %q", got.username) + } +} + +func TestSessionFromContextWithoutTokenCreatesDistinctSessions(t *testing.T) { + var createCalls atomic.Int32 + store := &flightAuthSessionStore{ + createSessionFn: func(context.Context, string) (int32, *server.FlightExecutor, error) { + return createCalls.Add(1), nil, nil + }, + destroySessionFn: func(int32) {}, + sessions: make(map[string]*flightClientSession), + byKey: make(map[string]string), + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) + base := peer.NewContext(context.Background(), &peer.Peer{ + Addr: &net.TCPAddr{IP: net.ParseIP("203.0.113.10"), Port: 45555}, + }) + ctx := metadata.NewIncomingContext(base, metadata.Pairs("authorization", "Basic "+token)) + + s1, err := h.sessionFromContext(ctx) + if err != nil { + t.Fatalf("first call failed: %v", err) + } + s2, err := h.sessionFromContext(ctx) + if err != nil { + t.Fatalf("second call failed: %v", err) + } + + if s1 == nil || s2 == nil { + t.Fatalf("expected non-nil sessions") + } + if s1 == s2 { + t.Fatalf("expected distinct sessions without session token") + } + if createCalls.Load() != 2 { + t.Fatalf("expected two independent session creations, got %d", createCalls.Load()) + } +} + +func TestFlightAuthSessionStoreGetExistingByKeyConcurrentStaleEntry(t *testing.T) { + store := &flightAuthSessionStore{ + sessions: make(map[string]*flightClientSession), + byKey: map[string]string{ + "stale-key": "missing-token", + }, + } + + const workers = 24 + const iterations = 1000 + + start := make(chan struct{}) + errCh := make(chan string, workers) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + for j := 0; j < iterations; j++ { + if _, ok := store.getExistingByKey("stale-key"); ok { + select { + case errCh <- "expected stale key lookup to miss": + default: + } + return + } + } + }() + } + close(start) + wg.Wait() + close(errCh) + if msg, ok := <-errCh; ok { + t.Fatalf("%s", msg) + } + + store.mu.RLock() + _, stillPresent := store.byKey["stale-key"] + store.mu.RUnlock() + if stillPresent { + t.Fatalf("expected stale key mapping to be pruned") + } +} + +func TestSessionFromContextRejectsExpiredSessionToken(t *testing.T) { + s := newFlightClientSession(1234, "postgres", nil) + s.token = "issued-token" + s.tokenIssuedAt.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + + store := &flightAuthSessionStore{ + tokenTTL: time.Hour, + sessions: map[string]*flightClientSession{ + "issued-token": s, + }, + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{"postgres": "postgres"}) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + token := base64.StdEncoding.EncodeToString([]byte("postgres:postgres")) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs( + "x-duckgres-session", "issued-token", + "authorization", "Basic "+token, + ), + ) + + if _, err := h.sessionFromContext(ctx); err == nil { + t.Fatalf("expected expired session token to be rejected") + } +} + +func TestSessionFromContextRejectsTokenUserMismatch(t *testing.T) { + store := &flightAuthSessionStore{ + sessions: map[string]*flightClientSession{ + "issued-token": newFlightClientSession(1234, "postgres", nil), + }, + } + + h, err := NewControlPlaneFlightSQLHandler(store, map[string]string{ + "postgres": "postgres", + "alice": "alice", + }) + if err != nil { + t.Fatalf("failed to construct handler: %v", err) + } + + token := base64.StdEncoding.EncodeToString([]byte("alice:alice")) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs( + "x-duckgres-session", "issued-token", + "authorization", "Basic "+token, + ), + ) + + if _, err := h.sessionFromContext(ctx); err == nil { + t.Fatalf("expected token/user mismatch to be rejected") } } diff --git a/server/server.go b/server/server.go index 7ecc360..4bafc13 100644 --- a/server/server.go +++ b/server/server.go @@ -101,8 +101,13 @@ type Config struct { // FlightHandleIdleTTL controls stale prepared/query handle cleanup inside a // Flight auth session. FlightHandleIdleTTL time.Duration - DataDir string - Users map[string]string // username -> password + + // FlightSessionTokenTTL controls the absolute lifetime of issued + // x-duckgres-session tokens. Expired tokens are rejected and require + // a fresh bootstrap request. + FlightSessionTokenTTL time.Duration + DataDir string + Users map[string]string // username -> password // TLS configuration (required unless ACME is configured) TLSCertFile string // Path to TLS certificate file diff --git a/tests/controlplane/flight_ingress_test.go b/tests/controlplane/flight_ingress_test.go index c7a4ab2..1148c6b 100644 --- a/tests/controlplane/flight_ingress_test.go +++ b/tests/controlplane/flight_ingress_test.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "strings" "sync" "testing" "time" @@ -27,6 +28,11 @@ func flightAuthContext(username, password string) context.Context { return metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Basic "+token)) } +func basicAuthHeader(username, password string) string { + token := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + return "Basic " + token +} + func newFlightClient(t *testing.T, port int) *flightsql.Client { t.Helper() addr := fmt.Sprintf("127.0.0.1:%d", port) @@ -44,14 +50,7 @@ func newFlightClient(t *testing.T, port int) *flightsql.Client { return client } -func requireGetTablesIncludeSchema(t *testing.T, flightPort int) error { - client := newFlightClient(t, flightPort) - defer func() { - _ = client.Close() - }() - ctx, cancel := context.WithTimeout(flightAuthContext("testuser", "testpass"), 20*time.Second) - defer cancel() - +func requireGetTablesIncludeSchema(t *testing.T, client *flightsql.Client, ctx context.Context) error { info, err := client.GetTables(ctx, &flightsql.GetTablesOpts{IncludeSchema: true}) if err != nil { return fmt.Errorf("GetTables(include_schema=true) failed: %w", err) @@ -76,7 +75,7 @@ func requireGetTablesIncludeSchema(t *testing.T, flightPort int) error { } for reader.Next() { - record := reader.Record() + record := reader.RecordBatch() names, ok := record.Column(2).(*array.String) if !ok { reader.Release() @@ -135,8 +134,34 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) { wg.Add(1) go func(workerID int) { defer wg.Done() + client := newFlightClient(t, h.flightPort) + defer func() { _ = client.Close() }() + + bootstrapCtx, bootstrapCancel := context.WithTimeout(flightAuthContext("testuser", "testpass"), 20*time.Second) + var respHeader metadata.MD + _, bootstrapErr := client.GetTables(bootstrapCtx, &flightsql.GetTablesOpts{}, grpc.Header(&respHeader)) + bootstrapCancel() + if bootstrapErr != nil { + errCh <- fmt.Errorf("worker %d bootstrap failed: %w", workerID, bootstrapErr) + return + } + sessionTokens := respHeader.Get("x-duckgres-session") + if len(sessionTokens) == 0 || strings.TrimSpace(sessionTokens[0]) == "" { + errCh <- fmt.Errorf("worker %d bootstrap missing x-duckgres-session header", workerID) + return + } + authHeader := basicAuthHeader("testuser", "testpass") + token := strings.TrimSpace(sessionTokens[0]) + for i := 0; i < iterationsPerGoroutine; i++ { - if err := requireGetTablesIncludeSchema(t, h.flightPort); err != nil { + iterBaseCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs( + "authorization", authHeader, + "x-duckgres-session", token, + )) + iterCtx, iterCancel := context.WithTimeout(iterBaseCtx, 20*time.Second) + err := requireGetTablesIncludeSchema(t, client, iterCtx) + iterCancel() + if err != nil { errCh <- fmt.Errorf("worker %d iteration %d failed: %w", workerID, i, err) return } @@ -151,3 +176,48 @@ func TestFlightIngressIncludeSchemaLowWorkerRegression(t *testing.T) { t.Fatalf("flight include_schema regression: %v\nLogs:\n%s", err, h.logBuf.String()) } } + +func TestFlightIngressServerIssuedSessionTokenRequiresBasicAuth(t *testing.T) { + h := startControlPlane(t, cpOpts{ + flightPort: freePort(t), + maxWorkers: 1, + }) + + client1 := newFlightClient(t, h.flightPort) + defer func() { _ = client1.Close() }() + + var respHeader metadata.MD + ctx1, cancel1 := context.WithTimeout(flightAuthContext("testuser", "testpass"), 20*time.Second) + defer cancel1() + + if _, err := client1.GetTables(ctx1, &flightsql.GetTablesOpts{}, grpc.Header(&respHeader)); err != nil { + t.Fatalf("bootstrap GetTables with basic auth failed: %v", err) + } + + sessionTokens := respHeader.Get("x-duckgres-session") + if len(sessionTokens) == 0 || strings.TrimSpace(sessionTokens[0]) == "" { + t.Fatalf("expected server-issued x-duckgres-session header, got %v", respHeader) + } + + client2 := newFlightClient(t, h.flightPort) + defer func() { _ = client2.Close() }() + + tokenCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("x-duckgres-session", sessionTokens[0])) + ctx2, cancel2 := context.WithTimeout(tokenCtx, 20*time.Second) + defer cancel2() + + if _, err := client2.GetTables(ctx2, &flightsql.GetTablesOpts{}); err == nil { + t.Fatalf("expected token-only GetTables to fail without basic auth") + } + + tokenAndAuthCtx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs( + "x-duckgres-session", sessionTokens[0], + "authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("testuser:testpass")), + )) + ctx3, cancel3 := context.WithTimeout(tokenAndAuthCtx, 20*time.Second) + defer cancel3() + + if _, err := client2.GetTables(ctx3, &flightsql.GetTablesOpts{}); err != nil { + t.Fatalf("token+basic GetTables failed: %v", err) + } +}