diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index a9b4b75dd..bb9ec70a4 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -29,9 +29,10 @@ services: PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql port=5432 user=postgres dbname=pgx_test" PGX_TEST_TCP_CONN_STRING: "host=127.0.0.1 port=5432 user=pgx_md5 password=secret dbname=pgx_test" PGX_TEST_MD5_PASSWORD_CONN_STRING: "host=127.0.0.1 port=5432 user=pgx_md5 password=secret dbname=pgx_test" - PGX_TEST_SCRAM_PASSWORD_CONN_STRING: "host=127.0.0.1 port=5432 user=pgx_scram password=secret dbname=pgx_test" + PGX_TEST_SCRAM_PASSWORD_CONN_STRING: "host=localhost port=5432 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + PGX_TEST_SCRAM_PLUS_CONN_STRING: "host=127.0.0.1 port=5432 user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" PGX_TEST_PLAIN_PASSWORD_CONN_STRING: "host=127.0.0.1 port=5432 user=pgx_pw password=secret dbname=pgx_test" - PGX_TEST_TLS_CONN_STRING: "host=localhost port=5432 user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + PGX_TEST_TLS_CONN_STRING: "host=localhost port=5432 user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" PGX_TEST_TLS_CLIENT_CONN_STRING: "host=localhost port=5432 user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" PGX_SSL_PASSWORD: certpw diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce7a012a4..02106aa1d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,50 +20,55 @@ jobs: pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" - pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + pgx-test-scram-plus-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" - pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - pg-version: 15 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" - pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + pgx-test-scram-plus-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" - pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - pg-version: 16 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" - pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + pgx-test-scram-plus-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" - pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - pg-version: 17 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" - pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + pgx-test-scram-plus-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" - pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - pg-version: 18 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" - pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" + pgx-test-scram-plus-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" - pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" pgx-test-oauth: "true" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" @@ -119,6 +124,7 @@ jobs: PGX_TEST_OAUTH: ${{ matrix.pgx-test-oauth }} # TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now. # PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + # PGX_TEST_SCRAM_PLUS_CONN_STRING: ${{ matrix.pgx-test-scram-plus-conn-string }} PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }} PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 25dc91826..2283ae670 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -80,10 +80,11 @@ export POSTGRESQL_DATA_DIR=postgresql export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test" export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" -export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test" +export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test channel_binding=disable" +export PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test channel_binding=require" export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret" -export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem" +export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem channel_binding=disable" export PGX_SSL_PASSWORD=certpw export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key" ``` diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index 9979087a0..f59d39c4e 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -1,7 +1,8 @@ -// SCRAM-SHA-256 authentication +// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication // // Resources: // https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc5929 // https://tools.ietf.org/html/rfc8265 // https://www.postgresql.org/docs/current/sasl-authentication.html // @@ -18,9 +19,13 @@ import ( "crypto/pbkdf2" "crypto/rand" "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" + "hash" "slices" "strconv" @@ -28,7 +33,11 @@ import ( "golang.org/x/text/secure/precis" ) -const clientNonceLen = 18 +const ( + clientNonceLen = 18 + scramSHA256Name = "SCRAM-SHA-256" + scramSHA256PlusName = "SCRAM-SHA-256-PLUS" +) // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { @@ -37,9 +46,35 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } + serverHasPlus := slices.Contains(sc.serverAuthMechanisms, scramSHA256PlusName) + if c.config.ChannelBinding == "require" && !serverHasPlus { + return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS") + } + + // If we have a TLS connection and channel binding is not disabled, attempt to + // extract the server certificate hash for tls-server-end-point channel binding. + if tlsConn, ok := c.conn.(*tls.Conn); ok && c.config.ChannelBinding != "disable" { + certHash, err := getTLSCertificateHash(tlsConn) + if err != nil && c.config.ChannelBinding == "require" { + return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", err) + } + + // Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it. + if certHash != nil && serverHasPlus { + sc.authMechanism = scramSHA256PlusName + } + + sc.channelBindingData = certHash + sc.hasTLS = true + } + + if c.config.ChannelBinding == "require" && sc.channelBindingData == nil { + return errors.New("channel binding required but channel binding data is not available") + } + // Send client-first-message in a SASLInitialResponse saslInitialResponse := &pgproto3.SASLInitialResponse{ - AuthMechanism: "SCRAM-SHA-256", + AuthMechanism: sc.authMechanism, Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) @@ -111,7 +146,28 @@ type scramClient struct { password string clientNonce []byte + // authMechanism is the selected SASL mechanism for the client. Must be + // either SCRAM-SHA-256 (default) or SCRAM-SHA-256-PLUS. + // + // Upgraded to SCRAM-SHA-256-PLUS during authentication when channel binding + // is not disabled, channel binding data is available (TLS connection with + // an obtainable server certificate hash) and the server advertises + // SCRAM-SHA-256-PLUS. + authMechanism string + + // hasTLS indicates whether the connection is using TLS. This is + // needed because the GS2 header must distinguish between a client that + // supports channel binding but the server does not ("y,,") versus one + // that does not support it at all ("n,,"). + hasTLS bool + + // channelBindingData is the hash of the server's TLS certificate, computed + // per the tls-server-end-point channel binding type (RFC 5929). Used as + // the binding input in SCRAM-SHA-256-PLUS. nil when not in use. + channelBindingData []byte + clientFirstMessageBare []byte + clientGS2Header []byte serverFirstMessage []byte clientAndServerNonce []byte @@ -125,11 +181,14 @@ type scramClient struct { func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { sc := &scramClient{ serverAuthMechanisms: serverAuthMechanisms, + authMechanism: scramSHA256Name, } - // Ensure server supports SCRAM-SHA-256 - hasScramSHA256 := slices.Contains(sc.serverAuthMechanisms, "SCRAM-SHA-256") - if !hasScramSHA256 { + // Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the + // channel binding variant and is only advertised when the server supports + // SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism + // regardless of SSL. + if !slices.Contains(sc.serverAuthMechanisms, scramSHA256Name) { return nil, errors.New("server does not support SCRAM-SHA-256") } @@ -153,8 +212,32 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien } func (sc *scramClient) clientFirstMessage() []byte { + // The client-first-message is the GS2 header concatenated with the bare + // message (username + client nonce). The GS2 header communicates the + // client's channel binding capability to the server: + // + // "n,," - client is not using TLS (channel binding not possible) + // "y,," - client is using TLS but channel binding is not + // in use (e.g., server did not advertise SCRAM-SHA-256-PLUS + // or the server certificate hash was not obtainable) + // "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS + // + // See: + // https://www.rfc-editor.org/rfc/rfc5802#section-6 + // https://www.rfc-editor.org/rfc/rfc5929#section-4 + // https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 + sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce) - return fmt.Appendf(nil, "n,,%s", sc.clientFirstMessageBare) + + if sc.authMechanism == scramSHA256PlusName { + sc.clientGS2Header = []byte("p=tls-server-end-point,,") + } else if sc.hasTLS { + sc.clientGS2Header = []byte("y,,") + } else { + sc.clientGS2Header = []byte("n,,") + } + + return append(sc.clientGS2Header, sc.clientFirstMessageBare...) } func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { @@ -213,7 +296,19 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { } func (sc *scramClient) clientFinalMessage() string { - clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=biws,r=%s", sc.clientAndServerNonce) + // The c= attribute carries the base64-encoded channel binding input. + // + // Without channel binding this is just the GS2 header alone ("biws" for + // "n,," or "eSws" for "y,,"). + // + // With channel binding, this is the GS2 header with the channel binding data + // (certificate hash) appended. + channelBindInput := sc.clientGS2Header + if sc.authMechanism == scramSHA256PlusName { + channelBindInput = slices.Concat(sc.clientGS2Header, sc.channelBindingData) + } + channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput) + clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce) var err error sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) @@ -269,3 +364,36 @@ func computeServerSignature(saltedPassword, authMessage []byte) []byte { base64.StdEncoding.Encode(buf, serverSignature) return buf } + +// Get the server certificate hash for SCRAM channel binding type +// tls-server-end-point. +func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) { + state := conn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return nil, errors.New("no peer certificates for channel binding") + } + + cert := state.PeerCertificates[0] + + // Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses + // MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature + // algorithm. + // + // See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1 + var h hash.Hash + switch cert.SignatureAlgorithm { + case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1: + h = sha256.New() + case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256: + h = sha256.New() + case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384: + h = sha512.New384() + case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512: + h = sha512.New() + default: + return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm) + } + + h.Write(cert.Raw) + return h.Sum(nil), nil +} diff --git a/pgconn/auth_scram_test.go b/pgconn/auth_scram_test.go new file mode 100644 index 000000000..dcb500f8f --- /dev/null +++ b/pgconn/auth_scram_test.go @@ -0,0 +1,436 @@ +package pgconn + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "fmt" + "math/big" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func generateSelfSignedCert(t *testing.T, sigAlg x509.SignatureAlgorithm) tls.Certificate { + t.Helper() + + var curve elliptic.Curve + switch sigAlg { + case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256: + curve = elliptic.P256() + case x509.ECDSAWithSHA384: + curve = elliptic.P384() + case x509.ECDSAWithSHA512: + curve = elliptic.P521() + default: + t.Fatalf("unsupported signature algorithm: %v", sigAlg) + } + + key, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + t.Fatal(err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + SignatureAlgorithm: sigAlg, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatal(err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatal(err) + } + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + Leaf: cert, + } +} + +// tlsConnWithCert performs a TLS handshake over a net.Pipe using the given +// certificate and returns the client-side *tls.Conn with peer certificates +// populated. +func tlsConnWithCert(t *testing.T, cert tls.Certificate) *tls.Conn { + t.Helper() + + clientConn, serverConn := net.Pipe() + + t.Cleanup(func() { + clientConn.Close() + serverConn.Close() + }) + + tlsServer := tls.Server(serverConn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + + tlsClient := tls.Client(clientConn, &tls.Config{ + InsecureSkipVerify: true, + }) + + errChan := make(chan error, 1) + go func() { errChan <- tlsServer.Handshake() }() + + require.NoError(t, tlsClient.Handshake()) + require.NoError(t, <-errChan) + + return tlsClient +} + +func TestGetTLSCertificateHash(t *testing.T) { + t.Parallel() + + t.Run("SHA1", func(t *testing.T) { + t.Parallel() + + // Per RFC 5929 section 4.1: SHA-1 signed certs use SHA-256 for the hash. + cert := generateSelfSignedCert(t, x509.ECDSAWithSHA1) + tlsConn := tlsConnWithCert(t, cert) + + hash, err := getTLSCertificateHash(tlsConn) + require.NoError(t, err) + require.Len(t, hash, sha256.Size) + }) + + t.Run("SHA256", func(t *testing.T) { + t.Parallel() + + cert := generateSelfSignedCert(t, x509.ECDSAWithSHA256) + tlsConn := tlsConnWithCert(t, cert) + + hash, err := getTLSCertificateHash(tlsConn) + require.NoError(t, err) + require.Len(t, hash, sha256.Size) + }) + + t.Run("SHA384", func(t *testing.T) { + t.Parallel() + + cert := generateSelfSignedCert(t, x509.ECDSAWithSHA384) + tlsConn := tlsConnWithCert(t, cert) + + hash, err := getTLSCertificateHash(tlsConn) + require.NoError(t, err) + require.Len(t, hash, sha512.Size384) + }) + + t.Run("SHA512", func(t *testing.T) { + t.Parallel() + + cert := generateSelfSignedCert(t, x509.ECDSAWithSHA512) + tlsConn := tlsConnWithCert(t, cert) + + hash, err := getTLSCertificateHash(tlsConn) + require.NoError(t, err) + require.Len(t, hash, sha512.Size) + }) +} + +func TestScramClientFirstMessage(t *testing.T) { + t.Parallel() + + t.Run("ChannelBindingNotSupported", func(t *testing.T) { + t.Parallel() + + client, err := newScramClient([]string{scramSHA256Name}, "secret") + require.NoError(t, err) + + firstMessage := client.clientFirstMessage() + + require.True(t, bytes.HasPrefix(firstMessage, []byte("n,,"))) + require.True(t, bytes.HasSuffix(firstMessage, client.clientNonce)) + }) + + t.Run("ChannelBindingClientSupported", func(t *testing.T) { + t.Parallel() + + client, err := newScramClient([]string{scramSHA256Name}, "secret") + require.NoError(t, err) + + client.authMechanism = scramSHA256Name + client.hasTLS = true + client.channelBindingData = []byte{1, 2, 3} + + firstMessage := client.clientFirstMessage() + require.True(t, bytes.HasPrefix(firstMessage, []byte("y,,"))) + }) + + t.Run("ChannelBindingTLSWithoutCertHash", func(t *testing.T) { + t.Parallel() + + // When on TLS but cert hash is unavailable (e.g., unsupported signature + // algorithm), the client should still send "y,," to enable downgrade + // detection per RFC 5802. + client, err := newScramClient([]string{scramSHA256Name}, "secret") + require.NoError(t, err) + + client.authMechanism = scramSHA256Name + client.hasTLS = true + client.channelBindingData = nil + + firstMessage := client.clientFirstMessage() + require.True(t, bytes.HasPrefix(firstMessage, []byte("y,,"))) + }) + + t.Run("ChannelBindingActive", func(t *testing.T) { + t.Parallel() + + client, err := newScramClient([]string{scramSHA256Name, scramSHA256PlusName}, "secret") + require.NoError(t, err) + + client.authMechanism = scramSHA256PlusName + client.channelBindingData = []byte{1, 2, 3} + + firstMessage := client.clientFirstMessage() + require.True(t, bytes.HasPrefix(firstMessage, []byte("p=tls-server-end-point,,"))) + }) +} + +func TestScramClientFinalMessage(t *testing.T) { + t.Parallel() + + setup := func(t *testing.T) *scramClient { + t.Helper() + + return &scramClient{ + clientNonce: []byte("testnonce"), + password: "secret", + authMechanism: scramSHA256Name, + } + } + + // withServerChallenge advances the scramClient through the client-first + // and server-first (challenge) messages, leaving it ready to produce the + // client-final-message. + withServerChallenge := func(t *testing.T, sc *scramClient) { + t.Helper() + + sc.clientFirstMessage() + + serverNonce := string(sc.clientNonce) + "servernonce" + salt := base64.StdEncoding.EncodeToString([]byte("testsalt")) + serverFirstMsg := fmt.Sprintf("r=%s,s=%s,i=4096", serverNonce, salt) + require.NoError(t, sc.recvServerFirstMessage([]byte(serverFirstMsg))) + } + + t.Run("ChannelBindingNone", func(t *testing.T) { + t.Parallel() + + sc := setup(t) + withServerChallenge(t, sc) + + msg := sc.clientFinalMessage() + + expected := base64.StdEncoding.EncodeToString([]byte("n,,")) + require.Contains(t, msg, "c="+expected) + }) + + t.Run("ChannelBindingClientSupports", func(t *testing.T) { + t.Parallel() + + sc := setup(t) + sc.hasTLS = true + sc.channelBindingData = []byte{1, 2, 3} + + withServerChallenge(t, sc) + + msg := sc.clientFinalMessage() + + expected := base64.StdEncoding.EncodeToString([]byte("y,,")) + require.Contains(t, msg, "c="+expected) + }) + + t.Run("ChannelBindingActive", func(t *testing.T) { + t.Parallel() + + sc := setup(t) + sc.authMechanism = scramSHA256PlusName + sc.hasTLS = true + sc.channelBindingData = []byte{1, 2, 3} + + withServerChallenge(t, sc) + + msg := sc.clientFinalMessage() + + expected := base64.StdEncoding.EncodeToString(append([]byte("p=tls-server-end-point,,"), 0x01, 0x02, 0x03)) + require.Contains(t, msg, "c="+expected) + }) +} + +func TestScramClientMechanismValidation(t *testing.T) { + t.Parallel() + + // Server does not support SSL. + _, err := newScramClient([]string{scramSHA256Name}, "password") + require.NoError(t, err) + + // Server supports SSL. + _, err = newScramClient([]string{scramSHA256PlusName, scramSHA256Name}, "password") + require.NoError(t, err) + + // Invalid. + _, err = newScramClient([]string{"MD5"}, "password") + require.Error(t, err) +} + +func TestScramClientRecvServerFirstMessage(t *testing.T) { + t.Parallel() + + clientNonce := "testnonce" + serverNonce := clientNonce + "servernonce" + salt := "testsalt" + saltEncoded := base64.StdEncoding.EncodeToString([]byte(salt)) + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + + // SCRAM server-first-message has the form: r=,s=,i= + validMsg := fmt.Sprintf("r=%s,s=%s,i=%d", serverNonce, saltEncoded, 4096) + + sc := &scramClient{clientNonce: []byte(clientNonce)} + err := sc.recvServerFirstMessage([]byte(validMsg)) + require.NoError(t, err) + + require.Equal(t, []byte(serverNonce), sc.clientAndServerNonce) + require.Equal(t, []byte(salt), sc.salt) + require.Equal(t, 4096, sc.iterations) + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + sc := &scramClient{clientNonce: []byte(clientNonce)} + + // Missing nonce. + { + err := sc.recvServerFirstMessage([]byte("s=" + saltEncoded + ",i=4096")) + require.Error(t, err) + require.Contains(t, err.Error(), "did not include r=") + } + + // Missing salt. + { + err := sc.recvServerFirstMessage([]byte("r=" + serverNonce + ",i=4096")) + require.Error(t, err) + require.Contains(t, err.Error(), "did not include s=") + } + + // Missing iterations. + { + err := sc.recvServerFirstMessage([]byte("r=" + serverNonce + ",s=" + saltEncoded)) + require.Error(t, err) + require.Contains(t, err.Error(), "did not include i=") + } + + // Invalid salt encoding. + { + err := sc.recvServerFirstMessage([]byte("r=" + serverNonce + ",s=%%%invalid,i=4096")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid SCRAM salt") + } + + // Non-numeric iteration count. + { + err := sc.recvServerFirstMessage([]byte("r=" + serverNonce + ",s=" + saltEncoded + ",i=notanumber")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid SCRAM iteration count") + } + + // Zero iteration count. + { + err := sc.recvServerFirstMessage([]byte("r=" + serverNonce + ",s=" + saltEncoded + ",i=0")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid SCRAM iteration count") + } + + // Nonce missing client prefix. + { + err := sc.recvServerFirstMessage([]byte("r=wrongnonce,s=" + saltEncoded + ",i=4096")) + require.Error(t, err) + require.Contains(t, err.Error(), "did not start with client nonce") + } + + // Nonce without server contribution. + { + err := sc.recvServerFirstMessage([]byte("r=" + clientNonce + ",s=" + saltEncoded + ",i=4096")) + require.Error(t, err) + require.Contains(t, err.Error(), "did not include server nonce") + } + }) +} + +func TestScramClientRecvServerFinalMessage(t *testing.T) { + t.Parallel() + + setup := func(t *testing.T) *scramClient { + t.Helper() + + // Build a scramClient that has completed the full message exchange up + // through clientFinalMessage, ready to receive server-final-message. + sc := &scramClient{ + clientNonce: []byte("testnonce"), + authMechanism: scramSHA256Name, + password: "secret", + } + sc.clientFirstMessage() + + serverNonce := string(sc.clientNonce) + "servernonce" + salt := base64.StdEncoding.EncodeToString([]byte("testsalt")) + serverFirstMsg := fmt.Sprintf("r=%s,s=%s,i=4096", serverNonce, salt) + require.NoError(t, sc.recvServerFirstMessage([]byte(serverFirstMsg))) + + sc.clientFinalMessage() + + return sc + } + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + + sc := setup(t) + + validSignature := computeServerSignature(sc.saltedPassword, sc.authMessage) + err := sc.recvServerFinalMessage(append([]byte("v="), validSignature...)) + require.NoError(t, err) + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + + sc := setup(t) + + // Missing server signature attribute. + { + err := sc.recvServerFinalMessage([]byte("e=some-error")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid SCRAM server-final-message") + } + + // Invalid server signature. + { + wrongSig := base64.StdEncoding.EncodeToString([]byte("wrong")) + err := sc.recvServerFinalMessage([]byte("v=" + wrongSig)) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid SCRAM ServerSignature") + } + }) +} diff --git a/pgconn/config.go b/pgconn/config.go index 648277d2d..dff550953 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -96,6 +96,10 @@ type Config struct { // Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility. MaxProtocolVersion string + // ChannelBinding is the channel_binding parameter for SCRAM-SHA-256-PLUS authentication. + // Valid values: "disable", "prefer", "require". Defaults to "prefer". + ChannelBinding string + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -355,6 +359,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "servicefile": {}, "min_protocol_version": {}, "max_protocol_version": {}, + "channel_binding": {}, } // Adding kerberos configuration @@ -468,6 +473,17 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.MaxProtocolVersion = "3.0" } + switch channelBinding := settings["channel_binding"]; channelBinding { + case "", "prefer": + config.ChannelBinding = "prefer" + case "disable": + config.ChannelBinding = "disable" + case "require": + config.ChannelBinding = "require" + default: + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown channel_binding value: %v", channelBinding)} + } + return config, nil } diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 2a2b1e284..57898eb88 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -1314,3 +1314,57 @@ func TestParseConfigProtocolVersion(t *testing.T) { }) } } + +func TestParseConfigChannelBinding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connString string + expected string + expectError bool + expectedErrText string + }{ + { + name: "defaults to prefer", + connString: "postgres://localhost/test", + expected: "prefer", + }, + { + name: "explicit prefer", + connString: "postgres://localhost/test?channel_binding=prefer", + expected: "prefer", + }, + { + name: "disable", + connString: "postgres://localhost/test?channel_binding=disable", + expected: "disable", + }, + { + name: "require", + connString: "postgres://localhost/test?channel_binding=require", + expected: "require", + }, + { + name: "invalid value", + connString: "postgres://localhost/test?channel_binding=invalid", + expectError: true, + expectedErrText: "unknown channel_binding value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(tt.connString) + if tt.expectError { + require.ErrorContains(t, err, tt.expectedErrText) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected, config.ChannelBinding) + }) + } +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index a0debce25..92d2d370e 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -102,24 +102,154 @@ func TestConnectWithOptions(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() + setup := func(t *testing.T, connString string) (*pgconn.PgConn, context.Context) { + t.Helper() - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + t.Cleanup(cancel) + + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + t.Cleanup(func() { closeConn(t, conn) }) + + return conn, ctx } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + t.Run("WithChannelBinding", func(t *testing.T) { + t.Parallel() - result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, 1) - require.Len(t, result.Rows[0], 1) - require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + connString := os.Getenv("PGX_TEST_SCRAM_PLUS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_SCRAM_PLUS_CONN_STRING") + } - closeConn(t, conn) + conn, ctx := setup(t, connString) + + result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + }) + + t.Run("WithoutChannelBinding", func(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, ctx := setup(t, connString) + + result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + }) +} + +func TestConnectChannelBinding(t *testing.T) { + t.Parallel() + + t.Run("RequireFailsWithoutTLS", func(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationSASL{AuthMechanisms: []string{"SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"}}), + pgmock.WaitForClose(), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second * 5)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s channel_binding=require", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, err = pgconn.Connect(ctx, connStr) + require.ErrorContains(t, err, "channel binding required but channel binding data is not available") + }) + + t.Run("RequireFailsWithoutServerPlus", func(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationSASL{AuthMechanisms: []string{"SCRAM-SHA-256"}}), + pgmock.WaitForClose(), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second * 5)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s channel_binding=require", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, err = pgconn.Connect(ctx, connStr) + require.ErrorContains(t, err, "channel binding required but server does not support SCRAM-SHA-256-PLUS") + }) } // TestConnectOAuth is separate from other connect tests because it specifically diff --git a/test.sh b/test.sh index 48e81fcbe..8bab2d280 100755 --- a/test.sh +++ b/test.sh @@ -121,9 +121,10 @@ run_tests() { PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \ PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ - PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test" \ + PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \ + PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \ PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \ - PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" \ + PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \ PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \ PGX_SSL_PASSWORD=certpw \ go test -count=1 "${extra_args[@]}" ./...; then