Skip to content

Commit 1744072

Browse files
committed
Port gravitational changes to microsoft driver
1 parent eaf0b71 commit 1744072

File tree

9 files changed

+439
-260
lines changed

9 files changed

+439
-260
lines changed

azuread/configuration.go

Lines changed: 238 additions & 223 deletions
Large diffs are not rendered by default.

azuread/driver.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import (
77
"context"
88
"database/sql"
99
"database/sql/driver"
10-
1110
mssql "github.com/microsoft/go-mssqldb"
11+
"github.com/microsoft/go-mssqldb/msdsn"
1212
)
1313

1414
// DriverName is the name used to register the driver
@@ -43,6 +43,16 @@ func NewConnector(dsn string) (*mssql.Connector, error) {
4343
return newConnectorConfig(config)
4444
}
4545

46+
// NewConnectorFromConfig returns a new connector with the provided configuration and additional parameters
47+
func NewConnectorFromConfig(dsnConfig msdsn.Config, params map[string]string) (*mssql.Connector, error) {
48+
config, err := newConfig(dsnConfig, params)
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
return newConnectorConfig(config)
54+
}
55+
4656
// newConnectorConfig creates a Connector from config.
4757
func newConnectorConfig(config *azureFedAuthConfig) (*mssql.Connector, error) {
4858
switch config.fedAuthLibrary {

buf.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ var bufpool = sync.Pool{
2828
},
2929
}
3030

31+
type TDSBuffer = tdsBuffer
32+
3133
// tdsBuffer reads and writes TDS packets of data to the transport.
3234
// The write and read buffers are separate to make sending attn signals
3335
// possible without locks. Currently attn signals are only sent during
@@ -59,6 +61,14 @@ type tdsBuffer struct {
5961
afterFirst func()
6062
}
6163

64+
// NewTdsBuffer returns an exported version of *tdsBuffer
65+
func NewTdsBuffer(buff []byte, size int) *TDSBuffer {
66+
return &tdsBuffer{
67+
rbuf: buff,
68+
rsize: size,
69+
}
70+
}
71+
6272
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
6373

6474
// pull an existing buf if one is available or get and add a new buf to the bufpool

error.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package mssql
22

33
import (
4+
"bytes"
45
"database/sql/driver"
6+
"encoding/binary"
57
"fmt"
68
)
79

@@ -23,6 +25,46 @@ type Error struct {
2325
All []Error
2426
}
2527

28+
// Marshal marshals the error to the wire protocol token.
29+
func (e *Error) Marshal() ([]byte, error) {
30+
buf := bytes.NewBuffer([]byte{
31+
byte(tokenError),
32+
})
33+
length := 2 + // length
34+
4 + // number
35+
1 + // state
36+
1 + // class
37+
(2 + 2*len(e.Message)) + // message
38+
(1 + 2*len(e.ServerName)) + // server name
39+
(1 + 2*len(e.ProcName)) + // proc name
40+
4 // line no
41+
if err := binary.Write(buf, binary.LittleEndian, uint16(length)); err != nil {
42+
return nil, err
43+
}
44+
if err := binary.Write(buf, binary.LittleEndian, e.Number); err != nil {
45+
return nil, err
46+
}
47+
if err := buf.WriteByte(e.State); err != nil {
48+
return nil, err
49+
}
50+
if err := buf.WriteByte(e.Class); err != nil {
51+
return nil, err
52+
}
53+
if err := writeUsVarChar(buf, e.Message); err != nil {
54+
return nil, err
55+
}
56+
if err := writeBVarChar(buf, e.ServerName); err != nil {
57+
return nil, err
58+
}
59+
if err := writeBVarChar(buf, e.ProcName); err != nil {
60+
return nil, err
61+
}
62+
if err := binary.Write(buf, binary.LittleEndian, e.LineNo); err != nil {
63+
return nil, err
64+
}
65+
return buf.Bytes(), nil
66+
}
67+
2668
func (e Error) Error() string {
2769
return "mssql: " + e.Message
2870
}

mssql.go

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ import (
2323

2424
// ReturnStatus may be used to return the return value from a proc.
2525
//
26-
// var rs mssql.ReturnStatus
27-
// _, err := db.Exec("theproc", &rs)
28-
// log.Printf("return status = %d", rs)
26+
// var rs mssql.ReturnStatus
27+
// _, err := db.Exec("theproc", &rs)
28+
// log.Printf("return status = %d", rs)
2929
type ReturnStatus int32
3030

3131
var driverInstance = &Driver{processQueryText: true}
@@ -150,6 +150,12 @@ func NewConnectorConfig(config msdsn.Config) *Connector {
150150
}
151151
}
152152

153+
type auth interface {
154+
InitialBytes() ([]byte, error)
155+
NextBytes([]byte) ([]byte, error)
156+
Free()
157+
}
158+
153159
// Connector holds the parsed DSN and is ready to make a new connection
154160
// at any time.
155161
//
@@ -169,6 +175,9 @@ type Connector struct {
169175
// callback that can provide a security token during ADAL login
170176
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)
171177

178+
// auth allows to provide a custom authenticator.
179+
auth auth
180+
172181
// SessionInitSQL is executed after marking a given session to be reset.
173182
// When not present, the next query will still reset the session to the
174183
// database defaults.
@@ -231,6 +240,16 @@ func (c *Conn) IsValid() bool {
231240
return c.connectionGood
232241
}
233242

243+
// GetUnderlyingConn returns underlying raw server connection.
244+
func (c *Conn) GetUnderlyingConn() io.ReadWriteCloser {
245+
return c.sess.buf.transport
246+
}
247+
248+
// GetLoginFlags returns tokens returned by server during login handshake.
249+
func (c *Conn) GetLoginFlags() []Token {
250+
return c.sess.loginFlags
251+
}
252+
234253
// checkBadConn marks the connection as bad based on the characteristics
235254
// of the supplied error. Bad connections will be dropped from the connection
236255
// pool rather than reused.
@@ -878,22 +897,24 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
878897
// not a variable length type ok should return false.
879898
// If length is not limited other than system limits, it should return math.MaxInt64.
880899
// The following are examples of returned values for various types:
881-
// TEXT (math.MaxInt64, true)
882-
// varchar(10) (10, true)
883-
// nvarchar(10) (10, true)
884-
// decimal (0, false)
885-
// int (0, false)
886-
// bytea(30) (30, true)
900+
//
901+
// TEXT (math.MaxInt64, true)
902+
// varchar(10) (10, true)
903+
// nvarchar(10) (10, true)
904+
// decimal (0, false)
905+
// int (0, false)
906+
// bytea(30) (30, true)
887907
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
888908
return makeGoLangTypeLength(r.cols[index].ti)
889909
}
890910

891911
// It should return
892912
// the precision and scale for decimal types. If not applicable, ok should be false.
893913
// The following are examples of returned values for various types:
894-
// decimal(38, 4) (38, 4, true)
895-
// int (0, 0, false)
896-
// decimal (math.MaxInt64, math.MaxInt64, true)
914+
//
915+
// decimal(38, 4) (38, 4, true)
916+
// int (0, 0, false)
917+
// decimal (math.MaxInt64, math.MaxInt64, true)
897918
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
898919
return makeGoLangTypePrecisionScale(r.cols[index].ti)
899920
}
@@ -1320,22 +1341,24 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
13201341
// not a variable length type ok should return false.
13211342
// If length is not limited other than system limits, it should return math.MaxInt64.
13221343
// The following are examples of returned values for various types:
1323-
// TEXT (math.MaxInt64, true)
1324-
// varchar(10) (10, true)
1325-
// nvarchar(10) (10, true)
1326-
// decimal (0, false)
1327-
// int (0, false)
1328-
// bytea(30) (30, true)
1344+
//
1345+
// TEXT (math.MaxInt64, true)
1346+
// varchar(10) (10, true)
1347+
// nvarchar(10) (10, true)
1348+
// decimal (0, false)
1349+
// int (0, false)
1350+
// bytea(30) (30, true)
13291351
func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
13301352
return makeGoLangTypeLength(r.cols[index].ti)
13311353
}
13321354

13331355
// It should return
13341356
// the precision and scale for decimal types. If not applicable, ok should be false.
13351357
// The following are examples of returned values for various types:
1336-
// decimal(38, 4) (38, 4, true)
1337-
// int (0, 0, false)
1338-
// decimal (math.MaxInt64, math.MaxInt64, true)
1358+
//
1359+
// decimal(38, 4) (38, 4, true)
1360+
// int (0, 0, false)
1361+
// decimal (math.MaxInt64, math.MaxInt64, true)
13391362
func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
13401363
return makeGoLangTypePrecisionScale(r.cols[index].ti)
13411364
}

tds.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ type tdsSession struct {
143143
logger ContextLogger
144144
routedServer string
145145
routedPort uint16
146+
loginFlags []Token
146147
}
147148

148149
const (
@@ -168,9 +169,23 @@ func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
168169

169170
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
170171
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
172+
w.BeginPacket(packetType, false)
173+
if err := WritePreLoginFields(w, fields); err != nil {
174+
return err
175+
}
176+
return w.FinishPacket()
177+
}
178+
179+
// Writer is an interface that combines Writer and ByteWriter.
180+
type Writer interface {
181+
io.Writer
182+
io.ByteWriter
183+
}
184+
185+
// WritePreLoginFields writes provided Pre-Login packet fields into the writer.
186+
func WritePreLoginFields(w Writer, fields map[uint8][]byte) error {
171187
var err error
172188

173-
w.BeginPacket(packetType, false)
174189
offset := uint16(5*len(fields) + 1)
175190
keys := make(keySlice, 0, len(fields))
176191
for k := range fields {
@@ -210,7 +225,7 @@ func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte)
210225
return errors.New("Write method didn't write the whole value")
211226
}
212227
}
213-
return w.FinishPacket()
228+
return nil
214229
}
215230

216231
func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
@@ -1195,6 +1210,15 @@ initiate_connection:
11951210
break
11961211
}
11971212

1213+
// Save options returned by the server so callers implementing
1214+
// proxies can pass them back to the original client.
1215+
switch tok.(type) {
1216+
case envChangeStruct, loginAckStruct, doneStruct:
1217+
if token, ok := tok.(Token); ok {
1218+
sess.loginFlags = append(sess.loginFlags, token)
1219+
}
1220+
}
1221+
11981222
switch token := tok.(type) {
11991223
case sspiMsg:
12001224
sspi_msg, err := auth.NextBytes(token)

0 commit comments

Comments
 (0)