Skip to content

Commit dad23d2

Browse files
authored
Feat: Add tracing data to prelogin and login7 packets (#228)
* add traceid field to prelogin * add clientid and pid to login7 * add logging of conn id to tdsSession * fix go mod * fix test * update min Go to 118 Fixes #226
1 parent 2521238 commit dad23d2

File tree

13 files changed

+306
-236
lines changed

13 files changed

+306
-236
lines changed

appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ environment:
1111
SQLUSER: sa
1212
SQLPASSWORD: Password12!
1313
DATABASE: test
14-
GOVERSION: 117
14+
GOVERSION: 118
1515
COLUMNENCRYPTION:
1616
APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
1717
RACE: -race -cpu 4

bulkcopy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
627627

628628
func (b *Bulk) dlogf(ctx context.Context, format string, v ...interface{}) {
629629
if b.Debug {
630-
b.cn.sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf(format, v...))
630+
b.cn.sess.LogF(ctx, msdsn.LogDebug, format, v...)
631631
}
632632
}

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module github.com/microsoft/go-mssqldb
22

3-
go 1.17
3+
go 1.18
44

55
require (
66
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
77
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
88
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1
99
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9
1010
github.com/golang-sql/sqlexp v0.1.0
11+
github.com/google/uuid v1.6.0
1112
github.com/jcmturner/gokrb5/v8 v8.4.4
1213
github.com/stretchr/testify v1.9.0
1314
golang.org/x/crypto v0.24.0
@@ -21,7 +22,6 @@ require (
2122
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
2223
github.com/davecgh/go-spew v1.1.1 // indirect
2324
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
24-
github.com/google/uuid v1.6.0 // indirect
2525
github.com/hashicorp/go-uuid v1.0.3 // indirect
2626
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
2727
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect

go.sum

Lines changed: 0 additions & 90 deletions
Large diffs are not rendered by default.

msdsn/conn_str.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"strings"
1515
"time"
1616
"unicode"
17+
18+
"github.com/google/uuid"
1719
)
1820

1921
type (
@@ -44,6 +46,8 @@ const (
4446
LogTransaction Log = 32
4547
LogDebug Log = 64
4648
LogRetries Log = 128
49+
// LogSessionIDs tells the session logger to include activity id and connection id
50+
LogSessionIDs Log = 0x8000
4751
)
4852

4953
const (
@@ -79,6 +83,7 @@ const (
7983
DialTimeout = "dial timeout"
8084
Pipe = "pipe"
8185
MultiSubnetFailover = "multisubnetfailover"
86+
NoTraceID = "notraceid"
8287
)
8388

8489
type Config struct {
@@ -131,6 +136,11 @@ type Config struct {
131136
ColumnEncryption bool
132137
// Attempt to connect to all IPs in parallel when MultiSubnetFailover is true
133138
MultiSubnetFailover bool
139+
// guid to set as Activity Id in the prelogin packet. Defaults to a new value for each Config.
140+
ActivityID []byte
141+
// When true, no connection id or trace id value is sent in the prelogin packet.
142+
// Some cloud servers may block connections that lack such values.
143+
NoTraceID bool
134144
}
135145

136146
func readDERFile(filename string) ([]byte, error) {
@@ -285,6 +295,10 @@ func Parse(dsn string) (Config, error) {
285295
Protocols: []string{},
286296
}
287297

298+
activityid, uerr := uuid.NewRandom()
299+
if uerr == nil {
300+
p.ActivityID = activityid[:]
301+
}
288302
var params map[string]string
289303
var err error
290304

@@ -504,6 +518,13 @@ func Parse(dsn string) (Config, error) {
504518
// Defaulting to true to prevent breaking change although other client libraries default to false
505519
p.MultiSubnetFailover = true
506520
}
521+
nti, ok := params[NoTraceID]
522+
if ok {
523+
notraceid, err := strconv.ParseBool(nti)
524+
if err == nil {
525+
p.NoTraceID = notraceid
526+
}
527+
}
507528
return p, nil
508529
}
509530

msdsn/conn_str_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,8 @@ func TestValidConnectionString(t *testing.T) {
105105
{"disableretry=1", func(p Config) bool { return p.DisableRetry }},
106106
{"disableretry=0", func(p Config) bool { return !p.DisableRetry }},
107107
{"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }},
108-
{"MultiSubnetFailover=true", func(p Config) bool { return p.MultiSubnetFailover }},
108+
{"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }},
109109
{"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }},
110-
111110
// those are supported currently, but maybe should not be
112111
{"someparam", func(p Config) bool { return true }},
113112
{";;=;", func(p Config) bool { return true }},
@@ -226,6 +225,9 @@ func TestConnParseRoundTripFixed(t *testing.T) {
226225
if err != nil {
227226
t.Fatal("Params after roundtrip are not valid", err)
228227
}
228+
t.Log("params.URL " + params.URL().String())
229+
params.ActivityID = nil
230+
rtParams.ActivityID = nil
229231
if !reflect.DeepEqual(params, rtParams) {
230232
t.Fatal("Parameters do not match after roundtrip", params, rtParams)
231233
}

mssql.go

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,7 @@ func (c *Conn) checkBadConn(ctx context.Context, err error, mayRetry bool) error
281281
}
282282

283283
if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
284-
if c.sess.logFlags&logRetries != 0 {
285-
c.sess.logger.Log(ctx, msdsn.LogRetries, err.Error())
286-
}
284+
c.sess.Log(ctx, msdsn.LogRetries, err.Error)
287285
return newRetryableError(err)
288286
}
289287

@@ -324,9 +322,7 @@ func (c *Conn) sendCommitRequest() error {
324322
reset := c.resetSession
325323
c.resetSession = false
326324
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
327-
if c.sess.logFlags&logErrors != 0 {
328-
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send CommitXact with %v", err))
329-
}
325+
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send CommitXact with %v", err)
330326
c.connectionGood = false
331327
return fmt.Errorf("faild to send CommitXact: %v", err)
332328
}
@@ -351,9 +347,7 @@ func (c *Conn) sendRollbackRequest() error {
351347
reset := c.resetSession
352348
c.resetSession = false
353349
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
354-
if c.sess.logFlags&logErrors != 0 {
355-
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send RollbackXact with %v", err))
356-
}
350+
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send RollbackXact with %v", err)
357351
c.connectionGood = false
358352
return fmt.Errorf("failed to send RollbackXact: %v", err)
359353
}
@@ -388,9 +382,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
388382
reset := c.resetSession
389383
c.resetSession = false
390384
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
391-
if c.sess.logFlags&logErrors != 0 {
392-
c.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send BeginXact with %v", err))
393-
}
385+
c.sess.LogF(ctx, msdsn.LogErrors, "Failed to send BeginXact with %v", err)
394386
c.connectionGood = false
395387
return fmt.Errorf("failed to send BeginXact: %v", err)
396388
}
@@ -524,15 +516,13 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
524516
conn := s.c
525517

526518
// no need to check number of parameters here, it is checked by database/sql
527-
if conn.sess.logFlags&logSQL != 0 {
528-
conn.sess.logger.Log(ctx, msdsn.LogSQL, s.query)
529-
}
519+
conn.sess.LogS(ctx, msdsn.LogSQL, s.query)
530520
if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
531521
for i := 0; i < len(args); i++ {
532522
if len(args[i].Name) > 0 {
533-
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@%s\t%v", args[i].Name, args[i].Value))
523+
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@%s\t%v", args[i].Name, args[i].Value)
534524
} else {
535-
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@p%d\t%v", i+1, args[i].Value))
525+
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@p%d\t%v", i+1, args[i].Value)
536526
}
537527
}
538528
}
@@ -542,9 +532,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
542532
isProc := isProc(s.query)
543533
if len(args) == 0 && !isProc {
544534
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
545-
if conn.sess.logFlags&logErrors != 0 {
546-
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send SqlBatch with %v", err))
547-
}
535+
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send SqlBatch with %v", err)
548536
conn.connectionGood = false
549537
return fmt.Errorf("failed to send SQL Batch: %v", err)
550538
}
@@ -567,9 +555,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
567555
params[1] = makeStrParam(strings.Join(decls, ","))
568556
}
569557
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
570-
if conn.sess.logFlags&logErrors != 0 {
571-
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send Rpc with %v", err))
572-
}
558+
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err)
573559
conn.connectionGood = false
574560
return fmt.Errorf("failed to send RPC: %v", err)
575561
}
@@ -1298,9 +1284,7 @@ func (rc *Rowsq) Columns() (res []string) {
12981284
for {
12991285
tok, err := rc.reader.nextToken()
13001286
if err == nil {
1301-
if rc.reader.sess.logFlags&logDebug != 0 {
1302-
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Columns() token type:%v", reflect.TypeOf(tok)))
1303-
}
1287+
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Columns() token type:%v", reflect.TypeOf(tok))
13041288
if tok == nil {
13051289
return []string{}
13061290
} else {
@@ -1327,9 +1311,7 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
13271311
}
13281312
for {
13291313
tok, err := rc.reader.nextToken()
1330-
if rc.reader.sess.logFlags&logDebug != 0 {
1331-
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Next() token type:%v", reflect.TypeOf(tok)))
1332-
}
1314+
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Next() token type:%v", reflect.TypeOf(tok))
13331315
if err == nil {
13341316
if tok == nil {
13351317
return io.EOF
@@ -1391,9 +1373,7 @@ func (rc *Rowsq) NextResultSet() error {
13911373
scan:
13921374
for {
13931375
tok, err := rc.reader.nextToken()
1394-
if rc.reader.sess.logFlags&logDebug != 0 {
1395-
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
1396-
}
1376+
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "NextResultSet() token type:%v", reflect.TypeOf(tok))
13971377

13981378
if err != nil {
13991379
return err

session.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package mssql
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/google/uuid"
8+
"github.com/microsoft/go-mssqldb/aecmk"
9+
"github.com/microsoft/go-mssqldb/msdsn"
10+
)
11+
12+
func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSession {
13+
sess := &tdsSession{
14+
buf: outbuf,
15+
logger: logger,
16+
logFlags: uint64(p.LogFlags),
17+
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
18+
}
19+
_ = sess.activityid.Scan(p.ActivityID)
20+
// generating a guid has a small chance of failure. Make a best effort
21+
connid, cerr := uuid.NewRandom()
22+
if cerr == nil {
23+
_ = sess.connid.Scan(connid[:])
24+
}
25+
26+
return sess
27+
}
28+
29+
func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
30+
instance_buf := []byte(p.Instance)
31+
instance_buf = append(instance_buf, 0) // zero terminate instance name
32+
33+
var encrypt byte
34+
switch p.Encryption {
35+
default:
36+
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
37+
case msdsn.EncryptionDisabled:
38+
encrypt = encryptNotSup
39+
case msdsn.EncryptionRequired:
40+
encrypt = encryptOn
41+
case msdsn.EncryptionOff:
42+
encrypt = encryptOff
43+
case msdsn.EncryptionStrict:
44+
encrypt = encryptStrict
45+
}
46+
v := getDriverVersion(driverVersion)
47+
fields := map[uint8][]byte{
48+
// 4 bytes for version and 2 bytes for minor version
49+
preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0},
50+
preloginENCRYPTION: {encrypt},
51+
preloginINSTOPT: instance_buf,
52+
preloginTHREADID: {0, 0, 0, 0},
53+
preloginMARS: {0}, // MARS disabled
54+
}
55+
56+
if !p.NoTraceID {
57+
traceID := make([]byte, 36) // 16 byte connection id + 16 byte activity id + 4 byte sequence number
58+
connid, _ := s.connid.Value()
59+
activityid, _ := s.activityid.Value()
60+
_ = copy(traceID[:16], connid.([]byte))
61+
_ = copy(traceID[16:32], activityid.([]byte))
62+
fields[preloginTRACEID] = traceID
63+
if (s.logFlags)&logDebug != 0 {
64+
msg := fmt.Sprintf("Creating prelogin packet with connection id '%s' and activity id '%s'", s.connid, s.activityid)
65+
s.logger.Log(ctx, msdsn.LogDebug, msg)
66+
}
67+
}
68+
if fe.FedAuthLibrary != FedAuthLibraryReserved {
69+
fields[preloginFEDAUTHREQUIRED] = []byte{1}
70+
}
71+
72+
return fields
73+
}
74+
75+
type logFunc func() string
76+
77+
func (s *tdsSession) logPrefix() string {
78+
if s.logFlags&uint64(msdsn.LogSessionIDs) != 0 {
79+
return fmt.Sprintf("aid:%v cid:%v - ", s.activityid, s.connid)
80+
}
81+
return ""
82+
}
83+
84+
func (s *tdsSession) LogS(ctx context.Context, category msdsn.Log, msg string) {
85+
s.Log(ctx, category, func() string { return msg })
86+
}
87+
88+
// Log checks that the session logFlags includes the category before evaluating the logFunc and emitting the trace
89+
func (s *tdsSession) Log(ctx context.Context, category msdsn.Log, logFunc logFunc) {
90+
if s.logFlags&uint64(category) != 0 {
91+
s.logger.Log(ctx, category, s.logPrefix()+logFunc())
92+
}
93+
}
94+
95+
// LogF checks that the session logFlags includes the category before calling fmt.Sprintf and emitting the trace
96+
func (s *tdsSession) LogF(ctx context.Context, category msdsn.Log, format string, a ...any) {
97+
if s.logFlags&uint64(category) != 0 {
98+
s.logger.Log(ctx, category, s.logPrefix()+fmt.Sprintf(format, a...))
99+
}
100+
}

0 commit comments

Comments
 (0)