Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 74 additions & 63 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"
"unicode"

"github.com/lib/pq/internal/proto"
"github.com/lib/pq/oid"
"github.com/lib/pq/scram"
)
Expand Down Expand Up @@ -1262,15 +1263,17 @@ func (cn *conn) startup(o values) {
}

for {
t, r := cn.recv()
switch t {
case 'K':
switch t, r := cn.recv(); proto.ResponseCode(t) {
case proto.BackendKeyData:
cn.processBackendKeyData(r)
case 'S':
case proto.ParameterStatus:
cn.processParameterStatus(r)
case 'R':
cn.auth(r, o)
case 'Z':
case proto.AuthenticationRequest:
err := cn.auth(r, o)
if err != nil {
panic(err)
}
case proto.ReadyForQuery:
cn.processReadyForQuery(r)
return
default:
Expand All @@ -1279,48 +1282,55 @@ func (cn *conn) startup(o values) {
}
}

func (cn *conn) auth(r *readBuf, o values) {
switch code := r.int32(); code {
case 0:
// OK
case 3:
w := cn.writeBuf('p')
func (cn *conn) auth(r *readBuf, o values) error {
switch code := proto.AuthCode(r.int32()); code {
default:
return fmt.Errorf("pq: unknown authentication response: %s", code)
case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
return fmt.Errorf("pq: unsupported authentication method: %s", code)

case proto.AuthReqOk:
return nil

case proto.AuthReqPassword:
w := cn.writeBuf(byte(proto.PasswordMessage))
w.string(o["password"])
cn.send(w)

t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
if t != byte(proto.AuthenticationRequest) {
return fmt.Errorf("pq: unexpected password response: %q", t)
}

if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
if r.int32() != int(proto.AuthReqOk) {
return fmt.Errorf("pq: unexpected authentication response: %q", t)
}
case 5:
return nil

case proto.AuthReqMD5:
s := string(r.next(4))
w := cn.writeBuf('p')
w := cn.writeBuf(byte(proto.PasswordMessage))
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
cn.send(w)

t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
if t != byte(proto.AuthenticationRequest) {
return fmt.Errorf("pq: unexpected password response: %q", t)
}

if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
if r.int32() != int(proto.AuthReqOk) {
return fmt.Errorf("pq: unexpected authentication response: %q", t)
}
case 7: // GSSAPI, startup
return nil

case proto.AuthReqGSS: // GSSAPI, startup
if newGss == nil {
errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)")
}
cli, err := newGss()
if err != nil {
errorf("kerberos error: %s", err.Error())
return fmt.Errorf("pq: kerberos error: %w", err)
}

var token []byte

if spn, ok := o["krbspn"]; ok {
// Use the supplied SPN if provided..
token, err = cli.GetInitTokenFromSpn(spn)
Expand All @@ -1330,103 +1340,104 @@ func (cn *conn) auth(r *readBuf, o values) {
if val, ok := o["krbsrvname"]; ok {
service = val
}

token, err = cli.GetInitToken(o["host"], service)
}

if err != nil {
errorf("failed to get Kerberos ticket: %q", err)
return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err)
}

w := cn.writeBuf('p')
w := cn.writeBuf(byte(proto.GSSResponse))
w.bytes(token)
cn.send(w)

// Store for GSSAPI continue message
cn.gss = cli
return nil

case 8: // GSSAPI continue

case proto.AuthReqGSSCont: // GSSAPI continue
if cn.gss == nil {
errorf("GSSAPI protocol error")
return errors.New("pq: GSSAPI protocol error")
}

b := []byte(*r)

done, tokOut, err := cn.gss.Continue(b)
done, tokOut, err := cn.gss.Continue([]byte(*r))
if err == nil && !done {
w := cn.writeBuf('p')
w := cn.writeBuf(byte(proto.SASLInitialResponse))
w.bytes(tokOut)
cn.send(w)
}

// Errors fall through and read the more detailed message
// from the server..
// Errors fall through and read the more detailed message from the
// server.
return nil

case 10:
case proto.AuthReqSASL:
sc := scram.NewClient(sha256.New, o["user"], o["password"])
sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
}
scOut := sc.Out()

w := cn.writeBuf('p')
w := cn.writeBuf(byte(proto.SASLResponse))
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)

t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
if t != byte(proto.AuthenticationRequest) {
return fmt.Errorf("pq: unexpected password response: %q", t)
}

if r.int32() != 11 {
errorf("unexpected authentication response: %q", t)
if r.int32() != int(proto.AuthReqSASLCont) {
return fmt.Errorf("pq: unexpected authentication response: %q", t)
}

nextStep := r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
}

scOut = sc.Out()
w = cn.writeBuf('p')
w = cn.writeBuf(byte(proto.SASLResponse))
w.bytes(scOut)
cn.send(w)

t, r = cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
if t != byte(proto.AuthenticationRequest) {
return fmt.Errorf("pq: unexpected password response: %q", t)
}

if r.int32() != 12 {
errorf("unexpected authentication response: %q", t)
if r.int32() != int(proto.AuthReqSASLFin) {
return fmt.Errorf("pq: unexpected authentication response: %q", t)
}

nextStep = r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
}

default:
errorf("unknown authentication response: %d", code)
return nil
}
}

type format int

const formatText format = 0
const formatBinary format = 1
const (
formatText format = 0
formatBinary format = 1
)

// One result-column format code with the value 1 (i.e. all binary).
var colFmtDataAllBinary = []byte{0, 1, 0, 1}
var (
// One result-column format code with the value 1 (i.e. all binary).
colFmtDataAllBinary = []byte{0, 1, 0, 1}

// No result-column format codes (i.e. all text).
var colFmtDataAllText = []byte{0, 0}
// No result-column format codes (i.e. all text).
colFmtDataAllText = []byte{0, 0}
)

type stmt struct {
cn *conn
Expand Down
63 changes: 63 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"github.com/lib/pq/internal/pqtest"
)

// Called for the side-effect of setting the environment.
func init() { pqtest.DSN("") }

const cancelErrorCode ErrorCode = "57014"

func getServerVersion(t *testing.T, db *sql.DB) int {
Expand Down Expand Up @@ -2431,3 +2434,63 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {
t.Fatalf("expected ErrInFailedTransaction; got %#v", err)
}
}

func TestAuth(t *testing.T) {
tests := []struct {
buf readBuf
wantErr string
}{
{readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`},
{readBuf{0, 0, 0, 99}, `unknown authentication response: <unknown> (99)`},
}

t.Parallel()
for _, tt := range tests {
t.Run("", func(t *testing.T) {
t.Run("unsupported auth", func(t *testing.T) {
err := (&conn{}).auth(&tt.buf, values{})
if !pqtest.ErrorContains(err, tt.wantErr) {
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
}
})
})
}

t.Run("end to end", func(t *testing.T) {
pqtest.SkipPgbouncer(t) // TODO: need to properly set up auth
pqtest.SkipPgpool(t) // TODO: need to properly set up auth

tests := []struct {
conn, wantErr string
}{
{"user=pqgomd5", `password authentication failed for user "pqgomd5"`},
{"user=pqgopassword", `empty password returned by client`},
{"user=pqgoscram", `password authentication failed for user "pqgoscram"`},

{"user=pqgomd5 password=wrong", `password authentication failed for user "pqgomd5"`},
{"user=pqgopassword password=wrong", `password authentication failed for user "pqgopassword"`},
{"user=pqgoscram password=wrong", `password authentication failed for user "pqgoscram"`},

{"user=pqgomd5 password=wordpass", ``},
{"user=pqgopassword password=wordpass", ``},
{"user=pqgoscram password=wordpass", ``},

{"user=pqgounknown password=wordpass", `role "pqgounknown" does not exist`},
}

for _, tt := range tests {
t.Run(tt.conn, func(t *testing.T) {
t.Parallel()
db, err := pqtest.DB(tt.conn)
if err != nil {
t.Fatal(err)
}

err = db.Ping()
if !pqtest.ErrorContains(err, tt.wantErr) {
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
}
})
}
})
}
67 changes: 67 additions & 0 deletions deprecated.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package pq

// PGError is an interface used by previous versions of pq.
//
// Deprecated: use the Error type. This is never used.
type PGError interface {
Error() string
Fatal() bool
Get(k byte) (v string)
}

// Get implements the legacy PGError interface.
//
// Deprecated: new code should use the fields of the Error struct directly.
func (e *Error) Get(k byte) (v string) {
switch k {
case 'S':
return e.Severity
case 'C':
return string(e.Code)
case 'M':
return e.Message
case 'D':
return e.Detail
case 'H':
return e.Hint
case 'P':
return e.Position
case 'p':
return e.InternalPosition
case 'q':
return e.InternalQuery
case 'W':
return e.Where
case 's':
return e.Schema
case 't':
return e.Table
case 'c':
return e.Column
case 'd':
return e.DataTypeName
case 'n':
return e.Constraint
case 'F':
return e.File
case 'L':
return e.Line
case 'R':
return e.Routine
}
return ""
}

// ParseURL converts a url to a connection string for driver.Open.
//
// Example:
//
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
//
// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...")
// now works, and calling this manually is no longer required.
func ParseURL(url string) (string, error) { return parseURL(url) }
Loading
Loading