@@ -324,6 +324,7 @@ func TestSNISupport(t *testing.T) {
324324 conn_param string
325325 hostname string
326326 expected_sni string
327+ direct bool
327328 }{
328329 {
329330 name : "SNI is set by default" ,
@@ -349,6 +350,19 @@ func TestSNISupport(t *testing.T) {
349350 hostname : "127.0.0.1" ,
350351 expected_sni : "" ,
351352 },
353+ {
354+ name : "SNI is set for negotiated ssl" ,
355+ conn_param : "sslnegotiation=postgres" ,
356+ hostname : "localhost" ,
357+ expected_sni : "localhost" ,
358+ },
359+ {
360+ name : "SNI is set for direct ssl" ,
361+ conn_param : "sslnegotiation=direct" ,
362+ hostname : "localhost" ,
363+ expected_sni : "localhost" ,
364+ direct : true ,
365+ },
352366 }
353367 for _ , tt := range tests {
354368 tt := tt
@@ -362,7 +376,7 @@ func TestSNISupport(t *testing.T) {
362376 }
363377 serverErrChan := make (chan error , 1 )
364378 serverSNINameChan := make (chan string , 1 )
365- go mockPostgresSSL (listener , serverErrChan , serverSNINameChan )
379+ go mockPostgresSSL (listener , tt . direct , serverErrChan , serverSNINameChan )
366380
367381 defer listener .Close ()
368382 defer close (serverErrChan )
@@ -397,7 +411,7 @@ func TestSNISupport(t *testing.T) {
397411//
398412// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
399413// While reading clientHello catch passed SNI data and report it to nameChan.
400- func mockPostgresSSL (listener net.Listener , errChan chan error , nameChan chan string ) {
414+ func mockPostgresSSL (listener net.Listener , direct bool , errChan chan error , nameChan chan string ) {
401415 var sniHost string
402416
403417 conn , err := listener .Accept ()
@@ -413,23 +427,25 @@ func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan st
413427 return
414428 }
415429
416- // Receive StartupMessage with SSL Request
417- startupMessage := make ([]byte , 8 )
418- if _ , err := io .ReadFull (conn , startupMessage ); err != nil {
419- errChan <- err
420- return
421- }
422- // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
423- if ! bytes .Equal (startupMessage , []byte {0 , 0 , 0 , 0x8 , 0x4 , 0xd2 , 0x16 , 0x2f }) {
424- errChan <- fmt .Errorf ("unexpected startup message: %#v" , startupMessage )
425- return
426- }
430+ if ! direct {
431+ // Receive StartupMessage with SSL Request
432+ startupMessage := make ([]byte , 8 )
433+ if _ , err := io .ReadFull (conn , startupMessage ); err != nil {
434+ errChan <- err
435+ return
436+ }
437+ // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
438+ if ! bytes .Equal (startupMessage , []byte {0 , 0 , 0 , 0x8 , 0x4 , 0xd2 , 0x16 , 0x2f }) {
439+ errChan <- fmt .Errorf ("unexpected startup message: %#v" , startupMessage )
440+ return
441+ }
427442
428- // Respond with SSLOk
429- _ , err = conn .Write ([]byte ("S" ))
430- if err != nil {
431- errChan <- err
432- return
443+ // Respond with SSLOk
444+ _ , err = conn .Write ([]byte ("S" ))
445+ if err != nil {
446+ errChan <- err
447+ return
448+ }
433449 }
434450
435451 // Set up TLS context to catch clientHello. It will always error out during handshake
0 commit comments