Skip to content

Commit ebc13f5

Browse files
authored
Test authentication (#1222)
Add some basic tests for authentication, as that was previously untested. GSS/Kerberos is not yet tested. Print a nicer "pq: unsupported authentication method: SSPI (9)" when people try to use SSPI. "pq: unknown authentication response: 9" was rather confusing. Add a new internal/proto package with constants for the PostgreSQL protocol. Also add pqtest.Exec() and pqtest.Query() to make writing tests a bit easier. Not used in this PR, but wrote them for debugging some other issues and might as well commit it here.
1 parent 57f291d commit ebc13f5

File tree

12 files changed

+388
-159
lines changed

12 files changed

+388
-159
lines changed

conn.go

Lines changed: 74 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"time"
2626
"unicode"
2727

28+
"github.com/lib/pq/internal/proto"
2829
"github.com/lib/pq/oid"
2930
"github.com/lib/pq/scram"
3031
)
@@ -1262,15 +1263,17 @@ func (cn *conn) startup(o values) {
12621263
}
12631264

12641265
for {
1265-
t, r := cn.recv()
1266-
switch t {
1267-
case 'K':
1266+
switch t, r := cn.recv(); proto.ResponseCode(t) {
1267+
case proto.BackendKeyData:
12681268
cn.processBackendKeyData(r)
1269-
case 'S':
1269+
case proto.ParameterStatus:
12701270
cn.processParameterStatus(r)
1271-
case 'R':
1272-
cn.auth(r, o)
1273-
case 'Z':
1271+
case proto.AuthenticationRequest:
1272+
err := cn.auth(r, o)
1273+
if err != nil {
1274+
panic(err)
1275+
}
1276+
case proto.ReadyForQuery:
12741277
cn.processReadyForQuery(r)
12751278
return
12761279
default:
@@ -1279,48 +1282,55 @@ func (cn *conn) startup(o values) {
12791282
}
12801283
}
12811284

1282-
func (cn *conn) auth(r *readBuf, o values) {
1283-
switch code := r.int32(); code {
1284-
case 0:
1285-
// OK
1286-
case 3:
1287-
w := cn.writeBuf('p')
1285+
func (cn *conn) auth(r *readBuf, o values) error {
1286+
switch code := proto.AuthCode(r.int32()); code {
1287+
default:
1288+
return fmt.Errorf("pq: unknown authentication response: %s", code)
1289+
case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
1290+
return fmt.Errorf("pq: unsupported authentication method: %s", code)
1291+
1292+
case proto.AuthReqOk:
1293+
return nil
1294+
1295+
case proto.AuthReqPassword:
1296+
w := cn.writeBuf(byte(proto.PasswordMessage))
12881297
w.string(o["password"])
12891298
cn.send(w)
12901299

12911300
t, r := cn.recv()
1292-
if t != 'R' {
1293-
errorf("unexpected password response: %q", t)
1301+
if t != byte(proto.AuthenticationRequest) {
1302+
return fmt.Errorf("pq: unexpected password response: %q", t)
12941303
}
1295-
1296-
if r.int32() != 0 {
1297-
errorf("unexpected authentication response: %q", t)
1304+
if r.int32() != int(proto.AuthReqOk) {
1305+
return fmt.Errorf("pq: unexpected authentication response: %q", t)
12981306
}
1299-
case 5:
1307+
return nil
1308+
1309+
case proto.AuthReqMD5:
13001310
s := string(r.next(4))
1301-
w := cn.writeBuf('p')
1311+
w := cn.writeBuf(byte(proto.PasswordMessage))
13021312
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
13031313
cn.send(w)
13041314

13051315
t, r := cn.recv()
1306-
if t != 'R' {
1307-
errorf("unexpected password response: %q", t)
1316+
if t != byte(proto.AuthenticationRequest) {
1317+
return fmt.Errorf("pq: unexpected password response: %q", t)
13081318
}
1309-
1310-
if r.int32() != 0 {
1311-
errorf("unexpected authentication response: %q", t)
1319+
if r.int32() != int(proto.AuthReqOk) {
1320+
return fmt.Errorf("pq: unexpected authentication response: %q", t)
13121321
}
1313-
case 7: // GSSAPI, startup
1322+
return nil
1323+
1324+
case proto.AuthReqGSS: // GSSAPI, startup
13141325
if newGss == nil {
1315-
errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
1326+
return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)")
13161327
}
13171328
cli, err := newGss()
13181329
if err != nil {
1319-
errorf("kerberos error: %s", err.Error())
1330+
return fmt.Errorf("pq: kerberos error: %w", err)
13201331
}
13211332

13221333
var token []byte
1323-
13241334
if spn, ok := o["krbspn"]; ok {
13251335
// Use the supplied SPN if provided..
13261336
token, err = cli.GetInitTokenFromSpn(spn)
@@ -1330,103 +1340,104 @@ func (cn *conn) auth(r *readBuf, o values) {
13301340
if val, ok := o["krbsrvname"]; ok {
13311341
service = val
13321342
}
1333-
13341343
token, err = cli.GetInitToken(o["host"], service)
13351344
}
13361345

13371346
if err != nil {
1338-
errorf("failed to get Kerberos ticket: %q", err)
1347+
return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err)
13391348
}
13401349

1341-
w := cn.writeBuf('p')
1350+
w := cn.writeBuf(byte(proto.GSSResponse))
13421351
w.bytes(token)
13431352
cn.send(w)
13441353

13451354
// Store for GSSAPI continue message
13461355
cn.gss = cli
1356+
return nil
13471357

1348-
case 8: // GSSAPI continue
1349-
1358+
case proto.AuthReqGSSCont: // GSSAPI continue
13501359
if cn.gss == nil {
1351-
errorf("GSSAPI protocol error")
1360+
return errors.New("pq: GSSAPI protocol error")
13521361
}
13531362

1354-
b := []byte(*r)
1355-
1356-
done, tokOut, err := cn.gss.Continue(b)
1363+
done, tokOut, err := cn.gss.Continue([]byte(*r))
13571364
if err == nil && !done {
1358-
w := cn.writeBuf('p')
1365+
w := cn.writeBuf(byte(proto.SASLInitialResponse))
13591366
w.bytes(tokOut)
13601367
cn.send(w)
13611368
}
13621369

1363-
// Errors fall through and read the more detailed message
1364-
// from the server..
1370+
// Errors fall through and read the more detailed message from the
1371+
// server.
1372+
return nil
13651373

1366-
case 10:
1374+
case proto.AuthReqSASL:
13671375
sc := scram.NewClient(sha256.New, o["user"], o["password"])
13681376
sc.Step(nil)
13691377
if sc.Err() != nil {
1370-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1378+
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
13711379
}
13721380
scOut := sc.Out()
13731381

1374-
w := cn.writeBuf('p')
1382+
w := cn.writeBuf(byte(proto.SASLResponse))
13751383
w.string("SCRAM-SHA-256")
13761384
w.int32(len(scOut))
13771385
w.bytes(scOut)
13781386
cn.send(w)
13791387

13801388
t, r := cn.recv()
1381-
if t != 'R' {
1382-
errorf("unexpected password response: %q", t)
1389+
if t != byte(proto.AuthenticationRequest) {
1390+
return fmt.Errorf("pq: unexpected password response: %q", t)
13831391
}
13841392

1385-
if r.int32() != 11 {
1386-
errorf("unexpected authentication response: %q", t)
1393+
if r.int32() != int(proto.AuthReqSASLCont) {
1394+
return fmt.Errorf("pq: unexpected authentication response: %q", t)
13871395
}
13881396

13891397
nextStep := r.next(len(*r))
13901398
sc.Step(nextStep)
13911399
if sc.Err() != nil {
1392-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1400+
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
13931401
}
13941402

13951403
scOut = sc.Out()
1396-
w = cn.writeBuf('p')
1404+
w = cn.writeBuf(byte(proto.SASLResponse))
13971405
w.bytes(scOut)
13981406
cn.send(w)
13991407

14001408
t, r = cn.recv()
1401-
if t != 'R' {
1402-
errorf("unexpected password response: %q", t)
1409+
if t != byte(proto.AuthenticationRequest) {
1410+
return fmt.Errorf("pq: unexpected password response: %q", t)
14031411
}
14041412

1405-
if r.int32() != 12 {
1406-
errorf("unexpected authentication response: %q", t)
1413+
if r.int32() != int(proto.AuthReqSASLFin) {
1414+
return fmt.Errorf("pq: unexpected authentication response: %q", t)
14071415
}
14081416

14091417
nextStep = r.next(len(*r))
14101418
sc.Step(nextStep)
14111419
if sc.Err() != nil {
1412-
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1420+
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
14131421
}
14141422

1415-
default:
1416-
errorf("unknown authentication response: %d", code)
1423+
return nil
14171424
}
14181425
}
14191426

14201427
type format int
14211428

1422-
const formatText format = 0
1423-
const formatBinary format = 1
1429+
const (
1430+
formatText format = 0
1431+
formatBinary format = 1
1432+
)
14241433

1425-
// One result-column format code with the value 1 (i.e. all binary).
1426-
var colFmtDataAllBinary = []byte{0, 1, 0, 1}
1434+
var (
1435+
// One result-column format code with the value 1 (i.e. all binary).
1436+
colFmtDataAllBinary = []byte{0, 1, 0, 1}
14271437

1428-
// No result-column format codes (i.e. all text).
1429-
var colFmtDataAllText = []byte{0, 0}
1438+
// No result-column format codes (i.e. all text).
1439+
colFmtDataAllText = []byte{0, 0}
1440+
)
14301441

14311442
type stmt struct {
14321443
cn *conn

conn_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import (
1818
"github.com/lib/pq/internal/pqtest"
1919
)
2020

21+
// Called for the side-effect of setting the environment.
22+
func init() { pqtest.DSN("") }
23+
2124
const cancelErrorCode ErrorCode = "57014"
2225

2326
func getServerVersion(t *testing.T, db *sql.DB) int {
@@ -2431,3 +2434,63 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {
24312434
t.Fatalf("expected ErrInFailedTransaction; got %#v", err)
24322435
}
24332436
}
2437+
2438+
func TestAuth(t *testing.T) {
2439+
tests := []struct {
2440+
buf readBuf
2441+
wantErr string
2442+
}{
2443+
{readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`},
2444+
{readBuf{0, 0, 0, 99}, `unknown authentication response: <unknown> (99)`},
2445+
}
2446+
2447+
t.Parallel()
2448+
for _, tt := range tests {
2449+
t.Run("", func(t *testing.T) {
2450+
t.Run("unsupported auth", func(t *testing.T) {
2451+
err := (&conn{}).auth(&tt.buf, values{})
2452+
if !pqtest.ErrorContains(err, tt.wantErr) {
2453+
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
2454+
}
2455+
})
2456+
})
2457+
}
2458+
2459+
t.Run("end to end", func(t *testing.T) {
2460+
pqtest.SkipPgbouncer(t) // TODO: need to properly set up auth
2461+
pqtest.SkipPgpool(t) // TODO: need to properly set up auth
2462+
2463+
tests := []struct {
2464+
conn, wantErr string
2465+
}{
2466+
{"user=pqgomd5", `password authentication failed for user "pqgomd5"`},
2467+
{"user=pqgopassword", `empty password returned by client`},
2468+
{"user=pqgoscram", `password authentication failed for user "pqgoscram"`},
2469+
2470+
{"user=pqgomd5 password=wrong", `password authentication failed for user "pqgomd5"`},
2471+
{"user=pqgopassword password=wrong", `password authentication failed for user "pqgopassword"`},
2472+
{"user=pqgoscram password=wrong", `password authentication failed for user "pqgoscram"`},
2473+
2474+
{"user=pqgomd5 password=wordpass", ``},
2475+
{"user=pqgopassword password=wordpass", ``},
2476+
{"user=pqgoscram password=wordpass", ``},
2477+
2478+
{"user=pqgounknown password=wordpass", `role "pqgounknown" does not exist`},
2479+
}
2480+
2481+
for _, tt := range tests {
2482+
t.Run(tt.conn, func(t *testing.T) {
2483+
t.Parallel()
2484+
db, err := pqtest.DB(tt.conn)
2485+
if err != nil {
2486+
t.Fatal(err)
2487+
}
2488+
2489+
err = db.Ping()
2490+
if !pqtest.ErrorContains(err, tt.wantErr) {
2491+
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
2492+
}
2493+
})
2494+
}
2495+
})
2496+
}

deprecated.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package pq
2+
3+
// PGError is an interface used by previous versions of pq.
4+
//
5+
// Deprecated: use the Error type. This is never used.
6+
type PGError interface {
7+
Error() string
8+
Fatal() bool
9+
Get(k byte) (v string)
10+
}
11+
12+
// Get implements the legacy PGError interface.
13+
//
14+
// Deprecated: new code should use the fields of the Error struct directly.
15+
func (e *Error) Get(k byte) (v string) {
16+
switch k {
17+
case 'S':
18+
return e.Severity
19+
case 'C':
20+
return string(e.Code)
21+
case 'M':
22+
return e.Message
23+
case 'D':
24+
return e.Detail
25+
case 'H':
26+
return e.Hint
27+
case 'P':
28+
return e.Position
29+
case 'p':
30+
return e.InternalPosition
31+
case 'q':
32+
return e.InternalQuery
33+
case 'W':
34+
return e.Where
35+
case 's':
36+
return e.Schema
37+
case 't':
38+
return e.Table
39+
case 'c':
40+
return e.Column
41+
case 'd':
42+
return e.DataTypeName
43+
case 'n':
44+
return e.Constraint
45+
case 'F':
46+
return e.File
47+
case 'L':
48+
return e.Line
49+
case 'R':
50+
return e.Routine
51+
}
52+
return ""
53+
}
54+
55+
// ParseURL converts a url to a connection string for driver.Open.
56+
//
57+
// Example:
58+
//
59+
// "postgres://bob:[email protected]:5432/mydb?sslmode=verify-full"
60+
//
61+
// converts to:
62+
//
63+
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
64+
//
65+
// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...")
66+
// now works, and calling this manually is no longer required.
67+
func ParseURL(url string) (string, error) { return parseURL(url) }

0 commit comments

Comments
 (0)