Skip to content

Commit 57f291d

Browse files
feat: support sslnegotiation flag (#1180)
1 parent 852fb3b commit 57f291d

File tree

5 files changed

+65
-31
lines changed

5 files changed

+65
-31
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ newer. Previously PostgreSQL 8.4 and newer were supported.
77

88
### Features
99

10-
- Add support for NamedValueChecker interface ([#1125])
10+
- Add support for NamedValueChecker interface ([#1125]).
11+
12+
- Support [`sslnegotiation`] to use SSL without negotiation ([#1180]).
1113

1214
- The `pq.Error.ErrorWithDetail()` method prints a more detailed multiline
1315
message, with the Detail, Hint, and error position (if any) ([#1219]):
@@ -55,6 +57,7 @@ newer. Previously PostgreSQL 8.4 and newer were supported.
5557

5658
- Treat nil []byte in query parameters as nil/NULL rather than `""` ([#838]).
5759

60+
[`sslnegotiation`]: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLNEGOTIATION
5861
[#595]: https://github.com/lib/pq/pull/595
5962
[#745]: https://github.com/lib/pq/pull/745
6063
[#743]: https://github.com/lib/pq/pull/743
@@ -68,6 +71,7 @@ newer. Previously PostgreSQL 8.4 and newer were supported.
6871
[#1161]: https://github.com/lib/pq/pull/1161
6972
[#1166]: https://github.com/lib/pq/pull/1166
7073
[#1179]: https://github.com/lib/pq/pull/1179
74+
[#1180]: https://github.com/lib/pq/pull/1180
7175
[#1184]: https://github.com/lib/pq/pull/1184
7276
[#1211]: https://github.com/lib/pq/pull/1211
7377
[#1212]: https://github.com/lib/pq/pull/1212

conn.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,20 +1184,24 @@ func (cn *conn) ssl(o values) error {
11841184
return nil
11851185
}
11861186

1187-
w := cn.writeBuf(0)
1188-
w.int32(80877103)
1189-
if err = cn.sendStartupPacket(w); err != nil {
1190-
return err
1191-
}
1187+
// only negotiate the ssl handshake if requested (which is the default).
1188+
// sllnegotiation=direct is supported by pg17 and above.
1189+
if sslnegotiation(o) {
1190+
w := cn.writeBuf(0)
1191+
w.int32(80877103)
1192+
if err = cn.sendStartupPacket(w); err != nil {
1193+
return err
1194+
}
11921195

1193-
b := cn.scratch[:1]
1194-
_, err = io.ReadFull(cn.c, b)
1195-
if err != nil {
1196-
return err
1197-
}
1196+
b := cn.scratch[:1]
1197+
_, err = io.ReadFull(cn.c, b)
1198+
if err != nil {
1199+
return err
1200+
}
11981201

1199-
if b[0] != 'S' {
1200-
return ErrSSLNotSupported
1202+
if b[0] != 'S' {
1203+
return ErrSSLNotSupported
1204+
}
12011205
}
12021206

12031207
cn.c, err = upgrade(cn.c)

doc.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ supported:
5555
- sslkey - Key file location. The file must contain PEM encoded data.
5656
- sslrootcert - The location of the root certificate file. The file
5757
must contain PEM encoded data.
58+
- sslnegotiation - when set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only).
5859
5960
Valid values for sslmode are:
6061

ssl.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,12 @@ func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error
211211
_, err = certs[0].Verify(opts)
212212
return err
213213
}
214+
215+
// sslnegotiation returns true if we should negotiate SSL.
216+
// returns false if there should be no negotiation and we should upgrade immediately.
217+
func sslnegotiation(o values) bool {
218+
if v, ok := o["sslnegotiation"]; ok && v == "direct" {
219+
return false
220+
}
221+
return true
222+
}

ssl_test.go

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ func TestSNISupport(t *testing.T) {
324324
conn_param string
325325
hostname string
326326
expected_sni string
327+
direct bool
327328
}{
328329
{
329330
name: "SNI is set by default",
@@ -349,6 +350,19 @@ func TestSNISupport(t *testing.T) {
349350
hostname: "127.0.0.1",
350351
expected_sni: "",
351352
},
353+
{
354+
name: "SNI is set for negotiated ssl",
355+
conn_param: "sslnegotiation=postgres",
356+
hostname: "localhost",
357+
expected_sni: "localhost",
358+
},
359+
{
360+
name: "SNI is set for direct ssl",
361+
conn_param: "sslnegotiation=direct",
362+
hostname: "localhost",
363+
expected_sni: "localhost",
364+
direct: true,
365+
},
352366
}
353367
for _, tt := range tests {
354368
tt := tt
@@ -362,7 +376,7 @@ func TestSNISupport(t *testing.T) {
362376
}
363377
serverErrChan := make(chan error, 1)
364378
serverSNINameChan := make(chan string, 1)
365-
go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
379+
go mockPostgresSSL(listener, tt.direct, serverErrChan, serverSNINameChan)
366380

367381
defer listener.Close()
368382
defer close(serverErrChan)
@@ -397,7 +411,7 @@ func TestSNISupport(t *testing.T) {
397411
//
398412
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
399413
// While reading clientHello catch passed SNI data and report it to nameChan.
400-
func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
414+
func mockPostgresSSL(listener net.Listener, direct bool, errChan chan error, nameChan chan string) {
401415
var sniHost string
402416

403417
conn, err := listener.Accept()
@@ -413,23 +427,25 @@ func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan st
413427
return
414428
}
415429

416-
// Receive StartupMessage with SSL Request
417-
startupMessage := make([]byte, 8)
418-
if _, err := io.ReadFull(conn, startupMessage); err != nil {
419-
errChan <- err
420-
return
421-
}
422-
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
423-
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
424-
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
425-
return
426-
}
430+
if !direct {
431+
// Receive StartupMessage with SSL Request
432+
startupMessage := make([]byte, 8)
433+
if _, err := io.ReadFull(conn, startupMessage); err != nil {
434+
errChan <- err
435+
return
436+
}
437+
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
438+
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
439+
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
440+
return
441+
}
427442

428-
// Respond with SSLOk
429-
_, err = conn.Write([]byte("S"))
430-
if err != nil {
431-
errChan <- err
432-
return
443+
// Respond with SSLOk
444+
_, err = conn.Write([]byte("S"))
445+
if err != nil {
446+
errChan <- err
447+
return
448+
}
433449
}
434450

435451
// Set up TLS context to catch clientHello. It will always error out during handshake

0 commit comments

Comments
 (0)