Skip to content

Commit ba24acc

Browse files
andrea-mgnAndrea Magnetto
andauthored
Fix GUID conversion (#207)
Co-authored-by: Andrea Magnetto <[email protected]> Adds a connection string value to preserve the raw guid byte order returned by SQL Server.
1 parent dad23d2 commit ba24acc

20 files changed

+223
-54
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Other supported formats are listed below.
6868
* `multisubnetfailover`
6969
* `true` (Default) Client attempt to connect to all IPs simultaneously.
7070
* `false` Client attempts to connect to IPs in serial.
71+
* `guid conversion` - Enables the conversion of GUIDs, so that byte order is preserved. UniqueIdentifier isn't supported for nullable fields, NullUniqueIdentifier must be used instead.
7172

7273
### Connection parameters for namedpipe package
7374
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`

alwaysencrypted_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
213213
func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) {
214214
t.Helper()
215215
testProvider := &testKeyProvider{fallback: provider}
216-
connector, _ := getTestConnector(t)
216+
connector, _ := getTestConnector(t, false /*guidConversion*/)
217217
connector.RegisterCekProvider(name, testProvider)
218218
conn := sql.OpenDB(connector)
219219
defer conn.Close()

bulkcopy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func (b *Bulk) createColMetadata() []byte {
264264
}
265265
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
266266

267-
writeTypeInfo(buf, &b.bulkColumns[i].ti, false)
267+
writeTypeInfo(buf, &b.bulkColumns[i].ti, false, b.cn.sess.encoding)
268268

269269
if col.ti.TypeId == typeNText ||
270270
col.ti.TypeId == typeText ||

bulkcopy_test.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
111111
}
112112
}
113113

114-
func TestBulkcopy(t *testing.T) {
114+
func testBulkcopy(t *testing.T, guidConversion bool) {
115115
// TDS level Bulk Insert is not supported on Azure SQL Server.
116-
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
116+
if dsn := makeConnStrSettingGuidConversion(t, guidConversion); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
117117
t.Skip("TDS level bulk copy is not supported on Azure SQL Server")
118118
}
119119
type testValue struct {
@@ -300,6 +300,14 @@ func TestBulkcopy(t *testing.T) {
300300
}
301301
}
302302

303+
func TestBulkcopyWithGuidConversion(t *testing.T) {
304+
testBulkcopy(t, true /*guidConversion*/)
305+
}
306+
307+
func TestBulkcopy(t *testing.T) {
308+
testBulkcopy(t, false /*guidConversion*/)
309+
}
310+
303311
func compareValue(a interface{}, expected interface{}) bool {
304312
if got, ok := a.([]uint8); ok {
305313
if _, ok := expected.([]uint8); !ok {

msdsn/conn_str.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,14 @@ const (
8484
Pipe = "pipe"
8585
MultiSubnetFailover = "multisubnetfailover"
8686
NoTraceID = "notraceid"
87+
GuidConversion = "guid conversion"
8788
)
8889

90+
type EncodeParameters struct {
91+
// Properly convert GUIDs, using correct byte endianness
92+
GuidConversion bool
93+
}
94+
8995
type Config struct {
9096
Port uint64
9197
Host string
@@ -141,6 +147,8 @@ type Config struct {
141147
// When true, no connection id or trace id value is sent in the prelogin packet.
142148
// Some cloud servers may block connections that lack such values.
143149
NoTraceID bool
150+
// Parameters related to type encoding
151+
Encoding EncodeParameters
144152
}
145153

146154
func readDERFile(filename string) ([]byte, error) {
@@ -525,6 +533,20 @@ func Parse(dsn string) (Config, error) {
525533
p.NoTraceID = notraceid
526534
}
527535
}
536+
537+
guidConversion, ok := params[GuidConversion]
538+
if ok {
539+
var err error
540+
p.Encoding.GuidConversion, err = strconv.ParseBool(guidConversion)
541+
if err != nil {
542+
f := "invalid guid conversion '%s': %s"
543+
return p, fmt.Errorf(f, guidConversion, err.Error())
544+
}
545+
} else {
546+
// set to false for backward compatibility
547+
p.Encoding.GuidConversion = false
548+
}
549+
528550
return p, nil
529551
}
530552

@@ -585,6 +607,11 @@ func (p Config) URL() *url.URL {
585607
if p.ColumnEncryption {
586608
q.Add("columnencryption", "true")
587609
}
610+
611+
if p.Encoding.GuidConversion {
612+
q.Add(GuidConversion, strconv.FormatBool(p.Encoding.GuidConversion))
613+
}
614+
588615
if len(q) > 0 {
589616
res.RawQuery = q.Encode()
590617
}

msdsn/conn_str_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ func TestValidConnectionString(t *testing.T) {
190190
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption
191191
}},
192192
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool {
193-
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption
193+
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion
194+
}},
195+
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool {
196+
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion
194197
}},
195198
}
196199
for _, ts := range connStrings {

mssql.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
554554
params[0] = makeStrParam(s.query)
555555
params[1] = makeStrParam(strings.Join(decls, ","))
556556
}
557-
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
557+
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset, conn.sess.encoding); err != nil {
558558
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err)
559559
conn.connectionGood = false
560560
return fmt.Errorf("failed to send RPC: %v", err)

mssql_go19.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
206206
err = errCalTypes
207207
return
208208
}
209-
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes)
209+
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes, s.c.sess.encoding)
210210
if err != nil {
211211
return
212212
}

queries_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ func driverWithProcess(t *testing.T, tl Logger) *Driver {
2727
}
2828
}
2929

30-
func TestSelect(t *testing.T) {
31-
conn, logger := open(t)
30+
func testSelect(t *testing.T, guidConversion bool) {
31+
conn, logger := openSettingGuidConversion(t, guidConversion)
3232
defer conn.Close()
3333
defer logger.StopLogging()
3434

@@ -39,6 +39,10 @@ func TestSelect(t *testing.T) {
3939
}
4040

4141
longstr := strings.Repeat("x", 10000)
42+
expectedGuid := []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
43+
if guidConversion {
44+
expectedGuid = []byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
45+
}
4246

4347
values := []testStruct{
4448
{"1", int64(1)},
@@ -83,8 +87,7 @@ func TestSelect(t *testing.T) {
8387
{"cast('2079-06-06T23:59:00' as smalldatetime)",
8488
time.Date(2079, 6, 6, 23, 59, 0, 0, time.UTC)},
8589
{"cast(NULL as smalldatetime)", nil},
86-
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)",
87-
[]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
90+
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)", expectedGuid},
8891
{"cast(NULL as uniqueidentifier)", nil},
8992
{"cast(0x1234 as varbinary(2))", []byte{0x12, 0x34}},
9093
{"cast(N'abc' as nvarchar(max))", "abc"},
@@ -114,8 +117,7 @@ func TestSelect(t *testing.T) {
114117
{"cast(cast(N'chào' as nvarchar(max)) collate Vietnamese_CI_AI as varchar(max))", "chào"}, // cp1258
115118
{fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), longstr},
116119
{"cast(NULL as sql_variant)", nil},
117-
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)",
118-
[]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
120+
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)", expectedGuid},
119121
{"cast(cast(1 as bit) as sql_variant)", true},
120122
{"cast(cast(10 as tinyint) as sql_variant)", int64(10)},
121123
{"cast(cast(-10 as smallint) as sql_variant)", int64(-10)},
@@ -214,6 +216,14 @@ func TestSelect(t *testing.T) {
214216
})
215217
}
216218

219+
func TestSelectWithGuidConversion(t *testing.T) {
220+
testSelect(t, true /*guidConversion*/)
221+
}
222+
223+
func TestSelect(t *testing.T) {
224+
testSelect(t, false /*guidConversion*/)
225+
}
226+
217227
func TestSelectDateTimeOffset(t *testing.T) {
218228
type testStruct struct {
219229
sql string

rpc.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package mssql
22

33
import (
44
"encoding/binary"
5+
6+
"github.com/microsoft/go-mssqldb/msdsn"
57
)
68

79
type procId struct {
@@ -43,7 +45,7 @@ var (
4345
)
4446

4547
// http://msdn.microsoft.com/en-us/library/dd357576.aspx
46-
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) {
48+
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool, encoding msdsn.EncodeParameters) (err error) {
4749
buf.BeginPacket(packRPCRequest, resetSession)
4850
writeAllHeaders(buf, headers)
4951
if len(proc.name) == 0 {
@@ -73,7 +75,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
7375
if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil {
7476
return
7577
}
76-
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0)
78+
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0, encoding)
7779
if err != nil {
7880
return
7981
}
@@ -82,7 +84,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
8284
return
8385
}
8486
if (param.Flags & fEncrypted) == fEncrypted {
85-
err = writeTypeInfo(buf, &param.tiOriginal, false)
87+
err = writeTypeInfo(buf, &param.tiOriginal, false, encoding)
8688
if err != nil {
8789
return
8890
}

0 commit comments

Comments
 (0)