diff --git a/.env b/.env index b11f0f030c..4fb2b9afc9 100644 --- a/.env +++ b/.env @@ -57,7 +57,15 @@ ADBC_JDBC_POSTGRESQL_PASSWORD=password ADBC_JDBC_POSTGRESQL_DATABASE=postgres ADBC_POSTGRESQL_TEST_URI="postgresql://localhost:5432/postgres?user=postgres&password=password" ADBC_SQLITE_FLIGHTSQL_URI=grpc+tcp://localhost:8080 -ADBC_TEST_FLIGHTSQL_URI=grpc+tcp://localhost:41414 +ADBC_TEST_FLIGHTSQL_URI=grpc+tls://localhost:41414 ADBC_GIZMOSQL_URI=grpc+tcp://localhost:31337 ADBC_GIZMOSQL_USER=adbc_test_user ADBC_GIZMOSQL_PASSWORD=adbc_test_password + +# OAuth test server configuration +# OAuth token endpoint (oauthserver on port 8181) +ADBC_OAUTH_TOKEN_URI=http://localhost:8181/token +ADBC_OAUTH_CLIENT_ID=test-client +ADBC_OAUTH_CLIENT_SECRET=test-secret +ADBC_OAUTH_SUBJECT_TOKEN=test-subject-token +ADBC_OAUTH_SKIP_VERIFY=true diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 814e17654d..2b3bb31179 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -728,7 +728,7 @@ jobs: docs/source/python/recipe/*.py - name: Test Recipes (Python) run: | - docker compose up --detach --wait dremio flightsql-sqlite-test postgres-test gizmosql-test + docker compose up --detach --wait dremio flightsql-sqlite-test postgres-test gizmosql-test oauth-server flightsql-test docker compose run --rm dremio-init export ADBC_CPP_RECIPE_BIN=~/local/bin # Needed for the combined C++/Python driver example diff --git a/ci/docker/oauth-server.dockerfile b/ci/docker/oauth-server.dockerfile new file mode 100644 index 0000000000..21844d693c --- /dev/null +++ b/ci/docker/oauth-server.dockerfile @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Simple OAuth 2.0 test server for ADBC FlightSQL OAuth testing +ARG GO +FROM golang:${GO} +EXPOSE 8181 diff --git a/compose.yaml b/compose.yaml index 358fd392c1..75d884fd7e 100644 --- a/compose.yaml +++ b/compose.yaml @@ -279,7 +279,7 @@ services: args: GO: ${GO} healthcheck: - test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc", "localhost:41414"] + test: ["CMD", "curl", "-k", "--http2", "-XPOST", "-H", "content-type: application/grpc", "https://localhost:41414"] interval: 5s timeout: 30s retries: 3 @@ -288,8 +288,35 @@ services: - "41414:41414" volumes: - .:/adbc:delegated + depends_on: + oauth-server: + condition: service_healthy + command: >- + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414 -token-prefix oauth- -tls" + + # OAuth test server for FlightSQL OAuth authentication testing + oauth-server: + container_name: adbc-oauth-server + image: ${REPO}:adbc-oauth-server + build: + context: . + cache_from: + - ${REPO}:adbc-oauth-server + dockerfile: ci/docker/oauth-server.dockerfile + args: + GO: ${GO} + healthcheck: + test: ["CMD", "curl", "--fail", "http://localhost:8181/health"] + interval: 5s + timeout: 10s + retries: 3 + start_period: 30s + ports: + - "8181:8181" + volumes: + - .:/adbc:delegated command: >- - /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414" + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/oauthserver -host 0.0.0.0 -port 8181 -client-id test-client -client-secret test-secret" flightsql-sqlite-test: image: ${REPO}:golang-${GO}-sqlite-flightsql diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst index 11cb7b1918..6c1ba7e9ac 100644 --- a/docs/source/driver/flight_sql.rst +++ b/docs/source/driver/flight_sql.rst @@ -215,46 +215,76 @@ OAuth 2.0 Options Supported configurations to obtain tokens using OAuth 2.0 authentication flows. ``adbc.flight.sql.oauth.flow`` - Specifies the OAuth 2.0 flow type to use. Possible values: ``client_credentials``, ``token_exchange`` + Specifies the OAuth 2.0 flow type to use. Possible values: ``client_credentials``, ``token_exchange`` + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_FLOW`, + :class:`adbc_driver_flightsql.OAuthFlowType` ``adbc.flight.sql.oauth.client_id`` - Unique identifier issued to the client application by the authorization server + Unique identifier issued to the client application by the authorization server + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_CLIENT_ID` ``adbc.flight.sql.oauth.client_secret`` - Secret associated to the client_id. Used to authenticate the client application to the authorization server + Secret associated to the client_id. Used to authenticate the client application to the authorization server + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_CLIENT_SECRET` ``adbc.flight.sql.oauth.token_uri`` - The endpoint URL where the client application requests tokens from the authorization server + The endpoint URL where the client application requests tokens from the authorization server + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_TOKEN_URI` ``adbc.flight.sql.oauth.scope`` - Space-separated list of permissions that the client is requesting access to (e.g ``"read.all offline_access"``) + Space-separated list of permissions that the client is requesting access to (e.g ``"read.all offline_access"``) + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_SCOPE` ``adbc.flight.sql.oauth.exchange.subject_token`` - The security token that the client application wants to exchange + The security token that the client application wants to exchange + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_SUBJECT_TOKEN` ``adbc.flight.sql.oauth.exchange.subject_token_type`` - Identifier for the type of the subject token. - Check list below for supported token types. + Identifier for the type of the subject token. + Check list below for supported token types. + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE`, + :class:`adbc_driver_flightsql.OAuthTokenType` ``adbc.flight.sql.oauth.exchange.actor_token`` - A security token that represents the identity of the acting party + A security token that represents the identity of the acting party + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_ACTOR_TOKEN` ``adbc.flight.sql.oauth.exchange.actor_token_type`` - Identifier for the type of the actor token. - Check list below for supported token types. + Identifier for the type of the actor token. + Check list below for supported token types. + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE`, + :class:`adbc_driver_flightsql.OAuthTokenType` + ``adbc.flight.sql.oauth.exchange.aud`` - The intended audience for the requested security token + The intended audience for the requested security token + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_AUD` ``adbc.flight.sql.oauth.exchange.resource`` - The resource server where the client intends to use the requested security token + The resource server where the client intends to use the requested security token + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_RESOURCE` ``adbc.flight.sql.oauth.exchange.scope`` - Specific permissions requested for the new token + Specific permissions requested for the new token + + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_SCOPE` ``adbc.flight.sql.oauth.exchange.requested_token_type`` - The type of token the client wants to receive in exchange. - Check list below for supported token types. + The type of token the client wants to receive in exchange. + Check list below for supported token types. + Python: :attr:`adbc_driver_flightsql.DatabaseOptions.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE`, + :class:`adbc_driver_flightsql.OAuthTokenType` Supported token types: - ``urn:ietf:params:oauth:token-type:access_token`` @@ -264,6 +294,8 @@ Supported token types: - ``urn:ietf:params:oauth:token-type:saml2`` - ``urn:ietf:params:oauth:token-type:jwt`` + Python: :class:`adbc_driver_flightsql.OAuthTokenType` + Distributed Result Sets ----------------------- diff --git a/docs/source/python/recipe/flight_sql.rst b/docs/source/python/recipe/flight_sql.rst index 6ceb5845c5..2a083a5cb4 100644 --- a/docs/source/python/recipe/flight_sql.rst +++ b/docs/source/python/recipe/flight_sql.rst @@ -61,3 +61,13 @@ Set the max gRPC message size ----------------------------- .. recipe:: flightsql_sqlite_max_msg_size.py + +Connect with OAuth 2.0 Client Credentials +----------------------------------------- + +.. recipe:: flightsql_oauth_client_credentials.py + +Connect with OAuth 2.0 Token Exchange +------------------------------------- + +.. recipe:: flightsql_oauth_token_exchange.py diff --git a/docs/source/python/recipe/flightsql_oauth_client_credentials.py b/docs/source/python/recipe/flightsql_oauth_client_credentials.py new file mode 100644 index 0000000000..b1559bc1fe --- /dev/null +++ b/docs/source/python/recipe/flightsql_oauth_client_credentials.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# RECIPE STARTS HERE + +#: The Flight SQL driver supports OAuth 2.0 authentication. This example shows +#: how to connect using the Client Credentials flow (RFC 6749), which is +#: suitable for machine-to-machine authentication without user interaction. + +import os + +import adbc_driver_flightsql.dbapi +from adbc_driver_flightsql import DatabaseOptions, OAuthFlowType + +uri = os.environ["ADBC_TEST_FLIGHTSQL_URI"] +token_uri = os.environ["ADBC_OAUTH_TOKEN_URI"] +client_id = os.environ["ADBC_OAUTH_CLIENT_ID"] +client_secret = os.environ["ADBC_OAUTH_CLIENT_SECRET"] + +#: Connect using OAuth 2.0 Client Credentials flow. +#: The driver will automatically obtain and refresh access tokens. + +db_kwargs = { + DatabaseOptions.OAUTH_FLOW.value: OAuthFlowType.CLIENT_CREDENTIALS.value, + DatabaseOptions.OAUTH_TOKEN_URI.value: token_uri, + DatabaseOptions.OAUTH_CLIENT_ID.value: client_id, + DatabaseOptions.OAUTH_CLIENT_SECRET.value: client_secret, + #: Optionally, request specific scopes + # DatabaseOptions.OAUTH_SCOPE.value: "dremio.all", +} + +#: For testing with self-signed certificates, skip TLS verification. +#: In production, you should provide proper TLS certificates. +if os.environ.get("ADBC_OAUTH_SKIP_VERIFY", "true").lower() in ("1", "true"): + db_kwargs[DatabaseOptions.TLS_SKIP_VERIFY.value] = "true" + +conn = adbc_driver_flightsql.dbapi.connect(uri, db_kwargs=db_kwargs) + +#: We can then execute queries as usual. + +with conn.cursor() as cur: + cur.execute("SELECT 1") + + result = cur.fetchone() + print(result) + +conn.close() diff --git a/docs/source/python/recipe/flightsql_oauth_token_exchange.py b/docs/source/python/recipe/flightsql_oauth_token_exchange.py new file mode 100644 index 0000000000..d678f49c2c --- /dev/null +++ b/docs/source/python/recipe/flightsql_oauth_token_exchange.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# RECIPE STARTS HERE + +#: The Flight SQL driver supports OAuth 2.0 Token Exchange (RFC 8693). This +#: allows exchanging an existing token (e.g., a JWT from an identity provider) +#: for a new token that can be used to access the Flight SQL service. + +import os + +import adbc_driver_flightsql.dbapi +from adbc_driver_flightsql import DatabaseOptions, OAuthFlowType, OAuthTokenType + +uri = os.environ["ADBC_TEST_FLIGHTSQL_URI"] +token_uri = os.environ["ADBC_OAUTH_TOKEN_URI"] +#: This is typically a JWT or other token from your identity provider +subject_token = os.environ["ADBC_OAUTH_SUBJECT_TOKEN"] + +#: For testing with self-signed certificates, skip TLS verification. +#: In production, you should provide proper TLS certificates. +db_kwargs = {} +if os.environ.get("ADBC_OAUTH_SKIP_VERIFY", "true").lower() in ("1", "true"): + db_kwargs[DatabaseOptions.TLS_SKIP_VERIFY.value] = "true" + +#: Connect using OAuth 2.0 Token Exchange flow. +#: The driver will exchange the subject token for an access token. + +db_kwargs.update( + { + DatabaseOptions.OAUTH_FLOW.value: OAuthFlowType.TOKEN_EXCHANGE.value, + DatabaseOptions.OAUTH_TOKEN_URI.value: token_uri, + DatabaseOptions.OAUTH_EXCHANGE_SUBJECT_TOKEN.value: subject_token, + #: Specify the type of the subject token being exchanged + DatabaseOptions.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.value: ( + OAuthTokenType.JWT.value + ), + #: Optionally, specify the type of token you want to receive + # DatabaseOptions.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.value: + # OAuthTokenType.ACCESS_TOKEN.value, + #: Optionally, specify the intended audience + # DatabaseOptions.OAUTH_EXCHANGE_AUD.value: "my-service", + } +) + +conn = adbc_driver_flightsql.dbapi.connect(uri, db_kwargs=db_kwargs) + +#: We can then execute queries as usual. + +with conn.cursor() as cur: + cur.execute("SELECT 1") + + result = cur.fetchone() + print(result) + +conn.close() diff --git a/go/adbc/driver/flightsql/cmd/oauthserver/main.go b/go/adbc/driver/flightsql/cmd/oauthserver/main.go new file mode 100644 index 0000000000..1f854a170d --- /dev/null +++ b/go/adbc/driver/flightsql/cmd/oauthserver/main.go @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A simple OAuth 2.0 test server supporting Client Credentials (RFC 6749) +// and Token Exchange (RFC 8693) flows for testing ADBC FlightSQL authentication. +package main + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "time" +) + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(v); err != nil { + log.Printf("Failed to encode JSON response: %v", err) + } +} + +func oauthError(w http.ResponseWriter, code, desc string) { + writeJSON(w, http.StatusBadRequest, map[string]string{ + "error": code, + "error_description": desc, + }) +} + +func issueToken(w http.ResponseWriter, prefix, scope string) { + token := fmt.Sprintf("oauth-%s-token-%d", prefix, time.Now().Unix()) + log.Printf("Issuing %s token: %s", prefix, token) + writeJSON(w, http.StatusOK, map[string]any{ + "access_token": token, + "token_type": "Bearer", + "expires_in": 3600, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": scope, + }) +} + +func main() { + host := flag.String("host", "0.0.0.0", "Host to bind") + port := flag.Int("port", 8181, "Port to bind") + clientID := flag.String("client-id", "test-client", "Expected client ID") + clientSecret := flag.String("client-secret", "test-secret", "Expected client secret") + flag.Parse() + + http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + oauthError(w, "invalid_request", "Failed to parse form data") + return + } + + grantType := r.FormValue("grant_type") + log.Printf("Token request: grant_type=%s", grantType) + + switch grantType { + case "client_credentials": + if r.FormValue("client_id") != *clientID || r.FormValue("client_secret") != *clientSecret { + oauthError(w, "invalid_client", "Invalid client credentials") + return + } + scope := r.FormValue("scope") + if scope == "" { + scope = "dremio.all" + } + issueToken(w, "cc", scope) + + case "urn:ietf:params:oauth:grant-type:token-exchange": + if r.FormValue("subject_token") == "" || r.FormValue("subject_token_type") == "" { + oauthError(w, "invalid_request", "Missing subject_token or subject_token_type") + return + } + issueToken(w, "exchange", r.FormValue("scope")) + + default: + oauthError(w, "unsupported_grant_type", fmt.Sprintf("Grant type '%s' not supported", grantType)) + } + }) + + http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("OK")) + }) + + addr := fmt.Sprintf("%s:%d", *host, *port) + log.Printf("Starting OAuth test server on %s (client_id=%s)", addr, *clientID) + log.Fatal(http.ListenAndServe(addr, nil)) +} diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go index 6d82bd0c8e..577bf3983a 100644 --- a/go/adbc/driver/flightsql/cmd/testserver/main.go +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -18,19 +18,30 @@ // A server intended specifically for testing the Flight SQL driver. Unlike // the upstream SQLite example, which tries to be functional, this server // tries to be useful. +// +// Supports optional OAuth authentication and TLS for testing OAuth flows. package main import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" "flag" "fmt" "log" + "math/big" "net" "os" "strconv" "strings" "sync" + "time" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" @@ -38,7 +49,9 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow-go/v18/arrow/memory" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -301,7 +314,7 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flight ch := make(chan flight.StreamChunk) schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) var rec arrow.RecordBatch - rec, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": 5}]`)) + rec, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"ints": 5}]`)) go func() { // wait for client cancel <-ctx.Done() @@ -361,7 +374,7 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flight } schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) - rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": 5}]`)) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"ints": 5}]`)) ch := make(chan flight.StreamChunk) go func() { @@ -538,10 +551,136 @@ func (srv *ExampleServer) CloseSession(ctx context.Context, req *flight.CloseSes return &flight.CloseSessionResult{}, nil } +// Hardcoded test credentials for Basic authentication +const ( + testBasicUsername = "user" + testBasicPassword = "password" +) + +// createAuthMiddleware creates gRPC interceptors that validate Bearer tokens or Basic auth. +// If tokenPrefix is empty, no validation is performed (authentication disabled). +// Supports both: +// - Bearer tokens: validated against the tokenPrefix +// - Basic auth: validated against hardcoded test credentials (user:password) +func createAuthMiddleware(tokenPrefix string) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { + validateAuth := func(ctx context.Context) error { + if tokenPrefix == "" { + return nil // No authentication required + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.InvalidArgument, "missing metadata") + } + + auth := md.Get("authorization") + if len(auth) == 0 { + return status.Error(codes.Unauthenticated, "missing authorization header") + } + + authHeader := auth[0] + + // Check for Basic authentication + if strings.HasPrefix(authHeader, "Basic ") { + encoded := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + log.Printf("Basic auth decode failed: %v", err) + return status.Error(codes.Unauthenticated, "invalid basic auth encoding") + } + + credentials := string(decoded) + parts := strings.SplitN(credentials, ":", 2) + if len(parts) != 2 { + return status.Error(codes.Unauthenticated, "invalid basic auth format") + } + + username, password := parts[0], parts[1] + if username == testBasicUsername && password == testBasicPassword { + log.Printf("Basic auth validated for user: %s", username) + return nil + } + log.Printf("Basic auth failed: invalid credentials for user: %s", username) + return status.Error(codes.Unauthenticated, "invalid credentials") + } + + // Check for Bearer token authentication + if strings.HasPrefix(authHeader, "Bearer ") { + bearerToken := strings.TrimPrefix(authHeader, "Bearer ") + if !strings.HasPrefix(bearerToken, tokenPrefix) { + log.Printf("Token validation failed: token=%s, expected prefix=%s", bearerToken, tokenPrefix) + return status.Error(codes.Unauthenticated, "invalid token") + } + + log.Printf("Token validated: %s", bearerToken[:min(len(bearerToken), 20)]+"...") + return nil + } + + return status.Error(codes.Unauthenticated, "invalid authorization format, expected 'Bearer ' or 'Basic '") + } + + unary := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := validateAuth(ctx); err != nil { + return nil, err + } + return handler(ctx, req) + } + + stream := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := validateAuth(ss.Context()); err != nil { + return err + } + return handler(srv, ss) + } + + return unary, stream +} + +// generateSelfSignedCert generates a self-signed TLS certificate for testing +func generateSelfSignedCert() (tls.Certificate, []byte, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"ADBC Test Server"}, + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("0.0.0.0")}, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("failed to create key pair: %w", err) + } + + return cert, certPEM, nil +} + func main() { var ( - host = flag.String("host", "localhost", "hostname to bind to") - port = flag.Int("port", 0, "port to bind to") + host = flag.String("host", "localhost", "hostname to bind to") + port = flag.Int("port", 0, "port to bind to") + useTLS = flag.Bool("tls", false, "Enable TLS with self-signed certificate") + tokenPrefix = flag.String("token-prefix", "", "Required prefix for valid Bearer tokens (empty = no auth)") + certFile = flag.String("cert-file", "", "Path to write the PEM certificate (for client verification)") ) flag.Parse() @@ -552,14 +691,54 @@ func main() { log.Fatal(err) } - server := flight.NewServerWithMiddleware(nil) + // Create middleware (OAuth validation if token-prefix is set) + var middleware []flight.ServerMiddleware + if *tokenPrefix != "" { + unary, stream := createAuthMiddleware(*tokenPrefix) + middleware = append(middleware, flight.ServerMiddleware{Unary: unary, Stream: stream}) + } + + addr := net.JoinHostPort(*host, strconv.Itoa(*port)) + var server flight.Server + + if *useTLS { + cert, certPEM, err := generateSelfSignedCert() + if err != nil { + log.Fatalf("Failed to generate TLS certificate: %v", err) + } + + if *certFile != "" { + if err := os.WriteFile(*certFile, certPEM, 0644); err != nil { + log.Fatalf("Failed to write certificate file: %v", err) + } + log.Printf("Certificate written to %s", *certFile) + } + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + server = flight.NewServerWithMiddleware(middleware, grpc.Creds(credentials.NewTLS(tlsConfig))) + } else { + server = flight.NewServerWithMiddleware(middleware) + } + server.RegisterFlightService(flightsql.NewFlightServer(srv)) - if err := server.Init(net.JoinHostPort(*host, strconv.Itoa(*port))); err != nil { + if err := server.Init(addr); err != nil { log.Fatal(err) } server.SetShutdownOnSignals(os.Interrupt, os.Kill) - fmt.Println("Starting testing Flight SQL Server on", server.Addr(), "...") + // Build descriptive startup message + features := []string{} + if *useTLS { + features = append(features, "TLS") + } + if *tokenPrefix != "" { + features = append(features, fmt.Sprintf("OAuth(prefix=%s)", *tokenPrefix)) + } + if len(features) > 0 { + fmt.Printf("Starting testing Flight SQL Server on %s with %s...\n", server.Addr(), strings.Join(features, ", ")) + } else { + fmt.Println("Starting testing Flight SQL Server on", server.Addr(), "...") + } if err := server.Serve(); err != nil { log.Fatal(err) diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py index eeaad22262..14a086b3a5 100644 --- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py +++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py @@ -28,11 +28,57 @@ __all__ = [ "ConnectionOptions", "DatabaseOptions", + "OAuthFlowType", + "OAuthTokenType", "StatementOptions", "connect", ] +class OAuthFlowType(enum.Enum): + """ + OAuth 2.0 flow types supported by the Flight SQL driver. + + Use these values with :attr:`DatabaseOptions.OAUTH_FLOW`. + """ + + #: OAuth 2.0 Client Credentials flow (RFC 6749 Section 4.4). + #: + #: Use when the client application needs to authenticate itself + #: to the authorization server using its own credentials. + CLIENT_CREDENTIALS = "client_credentials" + + #: OAuth 2.0 Token Exchange flow (RFC 8693). + #: + #: Use when the client application wants to exchange one + #: security token for another. + TOKEN_EXCHANGE = "token_exchange" + + +class OAuthTokenType(enum.Enum): + """ + OAuth 2.0 token types supported for token exchange (RFC 8693). + + Use these values with token type options like + :attr:`DatabaseOptions.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE`, + :attr:`DatabaseOptions.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE`, and + :attr:`DatabaseOptions.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE`. + """ + + #: An OAuth 2.0 access token. + ACCESS_TOKEN = "urn:ietf:params:oauth:token-type:access_token" + #: An OAuth 2.0 refresh token. + REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token" + #: An OpenID Connect ID token. + ID_TOKEN = "urn:ietf:params:oauth:token-type:id_token" + #: A SAML 1.1 assertion. + SAML1 = "urn:ietf:params:oauth:token-type:saml1" + #: A SAML 2.0 assertion. + SAML2 = "urn:ietf:params:oauth:token-type:saml2" + #: A JSON Web Token (JWT). + JWT = "urn:ietf:params:oauth:token-type:jwt" + + class DatabaseOptions(enum.Enum): """Database options specific to the Flight SQL driver.""" @@ -75,6 +121,59 @@ class DatabaseOptions(enum.Enum): #: Set the maximum gRPC message size (in bytes). The default is 16 MiB. WITH_MAX_MSG_SIZE = "adbc.flight.sql.client_option.with_max_msg_size" + # OAuth 2.0 options + + #: Specifies the OAuth 2.0 flow type to use. + #: + #: See :class:`OAuthFlowType` for possible values. + OAUTH_FLOW = "adbc.flight.sql.oauth.flow" + #: The authorization endpoint URL for OAuth 2.0. + OAUTH_AUTH_URI = "adbc.flight.sql.oauth.auth_uri" + #: The endpoint URL where the client application requests tokens + #: from the authorization server. + OAUTH_TOKEN_URI = "adbc.flight.sql.oauth.token_uri" + #: The redirect URI for OAuth 2.0 flows. + OAUTH_REDIRECT_URI = "adbc.flight.sql.oauth.redirect_uri" + #: Space-separated list of permissions that the client is requesting + #: access to (e.g., ``"read.all offline_access"``). + OAUTH_SCOPE = "adbc.flight.sql.oauth.scope" + #: Unique identifier issued to the client application by the + #: authorization server. + OAUTH_CLIENT_ID = "adbc.flight.sql.oauth.client_id" + #: Secret associated with the client_id. Used to authenticate the + #: client application to the authorization server. + OAUTH_CLIENT_SECRET = "adbc.flight.sql.oauth.client_secret" + + # OAuth 2.0 Token Exchange options (RFC 8693) + + #: The security token that the client application wants to exchange. + OAUTH_EXCHANGE_SUBJECT_TOKEN = "adbc.flight.sql.oauth.exchange.subject_token" + #: Identifier for the type of the subject token. + #: + #: See :class:`OAuthTokenType` for supported token types. + OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE = ( + "adbc.flight.sql.oauth.exchange.subject_token_type" + ) + #: A security token that represents the identity of the acting party. + OAUTH_EXCHANGE_ACTOR_TOKEN = "adbc.flight.sql.oauth.exchange.actor_token" + #: Identifier for the type of the actor token. + #: + #: See :class:`OAuthTokenType` for supported token types. + OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE = "adbc.flight.sql.oauth.exchange.actor_token_type" + #: The type of token the client wants to receive in exchange. + #: + #: See :class:`OAuthTokenType` for supported token types. + OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE = ( + "adbc.flight.sql.oauth.exchange.requested_token_type" + ) + #: Specific permissions requested for the new token in token exchange. + OAUTH_EXCHANGE_SCOPE = "adbc.flight.sql.oauth.exchange.scope" + #: The intended audience for the requested security token in token exchange. + OAUTH_EXCHANGE_AUD = "adbc.flight.sql.oauth.exchange.aud" + #: The resource server where the client intends to use the requested + #: security token in token exchange. + OAUTH_EXCHANGE_RESOURCE = "adbc.flight.sql.oauth.exchange.resource" + class ConnectionOptions(enum.Enum): """Connection options specific to the Flight SQL driver.""" diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py index 4c775d8e6c..93b6ec358f 100644 --- a/python/adbc_driver_flightsql/tests/conftest.py +++ b/python/adbc_driver_flightsql/tests/conftest.py @@ -80,8 +80,18 @@ def test_dbapi(): if not uri: pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests") + db_kwargs = { + # Skip TLS verification for self-signed certificates + adbc_driver_flightsql.DatabaseOptions.TLS_SKIP_VERIFY.value: "true", + # Use HTTP Basic authentication (user:password encoded as base64) + adbc_driver_flightsql.DatabaseOptions.AUTHORIZATION_HEADER.value: ( + "Basic dXNlcjpwYXNzd29yZA==" + ), + } + with adbc_driver_flightsql.dbapi.connect( uri, + db_kwargs=db_kwargs, autocommit=True, ) as conn: yield conn