Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions examples/channel_binding/tsql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package main

import (
"bufio"
"context"
"crypto/tls"
"database/sql"
"flag"
"fmt"
"io"
"log"
"os"
"time"

// mssqldb "github.com/denisenkom/go-mssqldb"
// "github.com/denisenkom/go-mssqldb/msdsn"
"github.com/google/uuid"
mssqldb "github.com/microsoft/go-mssqldb"
"github.com/microsoft/go-mssqldb/msdsn"
_ "github.com/microsoft/go-mssqldb/integratedauth/krb5"
)

func main() {
var (
userid = flag.String("U", "", "login_id")
password = flag.String("P", "", "password")
server = flag.String("S", "localhost", "server_name[\\instance_name]")
port = flag.Uint64("p", 1433, "server port")
keyLog = flag.String("K", "tlslog.log", "path to sslkeylog file")
database = flag.String("d", "", "db_name")
spn = flag.String("spn", "", "SPN")
auth = flag.String("a", "ntlm", "Authentication method: ntlm or krb5")
)
flag.Parse()

keyLogFile, err := os.OpenFile(*keyLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
log.Fatal("failed to open keylog file:", err)
}
defer keyLogFile.Close()


cfg := msdsn.Config{
User: *userid,
Database: *database,
Host: *server,
Port: *port,
Password: *password,
ChangePassword: "",
AppName: "go-mssqldb",
ServerSPN: *spn,
TLSConfig: &tls.Config{
InsecureSkipVerify: true, // adjust for your case
ServerName: *server,
KeyLogWriter: keyLogFile,
DynamicRecordSizingDisabled: true,
MinVersion: tls.VersionTLS11,
MaxVersion: tls.VersionTLS12,
},
Encryption: msdsn.EncryptionRequired,

Parameters: map[string]string{
"authenticator": *auth,
"krb5-credcachefile": "/tmp/krb5cc_719880",
"krb5-configfile": "/etc/krb5.conf",
},
ProtocolParameters: map[string]interface{}{

},
Protocols: []string{
"tcp",
},
Encoding: msdsn.EncodeParameters{
Timezone: time.UTC,
GuidConversion: false,
},
DialTimeout: time.Second * 5,
ConnTimeout: time.Second * 10,
KeepAlive: time.Second * 30,

}

// if *spn != "" {
// cfg.Parameters["authenticator"] = "krb5"
// // cfg.Parameters["krb5-credcachefile"] = "/tmp/krb5cc_719880"
// }

activityid, uerr := uuid.NewRandom()
if uerr == nil {
cfg.ActivityID = activityid[:]
}

workstation, err := os.Hostname()
if err == nil {
cfg.Workstation = workstation
}

connector := mssqldb.NewConnectorConfig(cfg)
// dsn := "server=" + *server + ";user id=" + *userid + ";password=" + *password + ";database=" + *database
// connector,err = mssqldb.NewConnector(dsn)
// if err != nil {
// fmt.Println("failed to create connector: ", err.Error())
// return
// }

_, err = connector.Connect(context.Background())
if err != nil {
fmt.Println("connector.Connect: ", err.Error())
return
}

db := sql.OpenDB(connector)
defer db.Close()

// if err != nil {
// fmt.Println("Cannot connect: ", err.Error())
// return
// }
err = db.Ping()
if err != nil {
fmt.Println("Cannot connect: ", err.Error())
return
}
r := bufio.NewReader(os.Stdin)
for {
_, err = os.Stdout.Write([]byte("> "))
if err != nil {
fmt.Println(err)
return
}
cmd, err := r.ReadString('\n')
if err != nil {
if err == io.EOF {
fmt.Println()
return
}
fmt.Println(err)
return
}
err = exec(db, cmd)
if err != nil {
fmt.Println(err)
}
}
}
func exec(db *sql.DB, cmd string) error {
rows, err := db.Query(cmd)
if err != nil {
return err
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
return err
}
if cols == nil {
return nil
}
vals := make([]interface{}, len(cols))
for i := 0; i < len(cols); i++ {
vals[i] = new(interface{})
if i != 0 {
fmt.Print("\t")
}
fmt.Print(cols[i])
}
fmt.Println()
for rows.Next() {
err = rows.Scan(vals...)
if err != nil {
fmt.Println(err)
continue
}
for i := 0; i < len(vals); i++ {
if i != 0 {
fmt.Print("\t")
}
printValue(vals[i].(*interface{}))
}
fmt.Println()

}
if rows.Err() != nil {
return rows.Err()
}
return nil
}

func printValue(pval *interface{}) {
switch v := (*pval).(type) {
case nil:
fmt.Print("NULL")
case bool:
if v {
fmt.Print("1")
} else {
fmt.Print("0")
}
case []byte:
fmt.Print(string(v))
case time.Time:
fmt.Print(v.Format("2006-01-02 15:04:05.999"))
default:
fmt.Print(v)
}
}
1 change: 1 addition & 0 deletions integratedauth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type stubAuth struct {
func (s *stubAuth) InitialBytes() ([]byte, error) { return nil, nil }
func (s *stubAuth) NextBytes([]byte) ([]byte, error) { return nil, nil }
func (s *stubAuth) Free() {}
func (s *stubAuth) SetChannelBinding([]byte) {}

func getAuth(config msdsn.Config) (IntegratedAuthenticator, error) {
return &stubAuth{config.User}, nil
Expand Down
28 changes: 28 additions & 0 deletions integratedauth/channel_binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package integratedauth

import (
"crypto/md5"
"encoding/binary"
)

func GenerateCBTFromTLSUnique(tlsUnique []byte) []byte {
// Initialize the channel binding structure with empty addresses
// These fields are defined in the RFC but not used for TLS bindings
initiatorAddress := make([]byte, 8)
acceptorAddress := make([]byte, 8)

// Create the application data with the "tls-unique:" prefix
applicationDataRaw := append([]byte("tls-unique:"), tlsUnique...)

// Add the length prefix to the application data (little-endian 32-bit integer)
lenApplicationData := make([]byte, 4)
binary.LittleEndian.PutUint32(lenApplicationData, uint32(len(applicationDataRaw)))
applicationData := append(lenApplicationData, applicationDataRaw...)

// Assemble the complete channel binding structure
channelBindingStruct := append(append(initiatorAddress, acceptorAddress...), applicationData...)

// Return the MD5 hash of the structure
hash := md5.Sum(channelBindingStruct)
return hash[:]
}
1 change: 1 addition & 0 deletions integratedauth/integratedauthenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type IntegratedAuthenticator interface {
InitialBytes() ([]byte, error)
NextBytes([]byte) ([]byte, error)
Free()
SetChannelBinding([]byte)
}

// ProviderFunc is an adapter to convert a GetIntegratedAuthenticator func into a Provider
Expand Down
5 changes: 5 additions & 0 deletions integratedauth/krb5/krb5.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ type krbAuth struct {
krb5Config *krb5Login
spnegoClient *spnego.SPNEGO
krb5Client *client.Client
channelBinding []byte
}

func (k *krbAuth) SetChannelBinding(channelBinding []byte) {
k.channelBinding = channelBinding
}

Copy link

Copilot AI Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The channelBinding field is stored but never used in the Kerberos authentication flow. The field should either be integrated into the authentication logic (e.g., passed to the SPNEGO client) or removed if channel binding is not supported for Kerberos authentication in this implementation.

Suggested change
channelBinding []byte
}
func (k *krbAuth) SetChannelBinding(channelBinding []byte) {
k.channelBinding = channelBinding
}
}

Copilot uses AI. Check for mistakes.
func (k *krbAuth) InitialBytes() ([]byte, error) {
Expand Down
26 changes: 24 additions & 2 deletions integratedauth/ntlm/ntlm.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,20 @@ const _NEGOTIATE_FLAGS = _NEGOTIATE_UNICODE |
_NEGOTIATE_ALWAYS_SIGN |
_NEGOTIATE_EXTENDED_SESSIONSECURITY

const (
AV_PAIR_MsvAvChannelBindings = 0x000A
)

type Auth struct {
Domain string
UserName string
Password string
Workstation string
ChannelBinding []byte
}

func (auth *Auth) SetChannelBinding(channelBinding []byte) {
auth.ChannelBinding = channelBinding
}

// getAuth returns an authentication handle Auth to provide authentication content
Expand All @@ -76,6 +85,7 @@ func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error
UserName: domainUser[1],
Password: config.Password,
Workstation: config.Workstation,
ChannelBinding: []byte{},
}, nil
}

Expand Down Expand Up @@ -243,7 +253,7 @@ func getNTLMv2AndLMv2ResponsePayloads(userDomain, username, password string, cha
return
}

func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8]byte, username, password, userDom string) (lm, nt []byte, err error) {
func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8]byte, username, password, userDom string, channelBinding []byte) (lm, nt []byte, err error) {
nonce := clientChallenge()

// Official specification: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/b38c36ed-2804-4868-a9ff-8dd3182128e4
Expand All @@ -254,6 +264,18 @@ func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8
return lm, nt, err
}

if len(channelBinding) > 0 {
av_pair_cb := make([]byte, 4)
// AvId
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e
Comment on lines +269 to +270
Copy link

Copilot AI Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment "AvId" on line 269 is unclear and doesn't explain what this code is doing. Consider expanding this to explain that it's creating the AV_PAIR structure with the AvId field set to MsvAvChannelBindings and the AvLen field set to the length of the channel binding data, per the MS-NLMP specification.

Suggested change
// AvId
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e
// Create the AV_PAIR structure for channel bindings as specified in MS-NLMP.
// Set AvId to MsvAvChannelBindings and AvLen to the length of the channel binding data.
// See: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e

Copilot uses AI. Check for mistakes.
binary.LittleEndian.PutUint16(av_pair_cb[0:2], AV_PAIR_MsvAvChannelBindings)
binary.LittleEndian.PutUint16(av_pair_cb[2:4], uint16(len(channelBinding)))
av_pair_cb = append(av_pair_cb, channelBinding...)

targetInfoFields = append(targetInfoFields[:len(targetInfoFields)-4], av_pair_cb...)
targetInfoFields = append(targetInfoFields, 0,0,0,0)
}

nt, lm = getNTLMv2AndLMv2ResponsePayloads(userDom, username, password, challenge, nonce, targetInfoFields, time.Now())

return lm, nt, nil
Expand Down Expand Up @@ -376,7 +398,7 @@ func (auth *Auth) NextBytes(bytes []byte) ([]byte, error) {
copy(challenge[:], bytes[24:32])
flags := binary.LittleEndian.Uint32(bytes[20:24])
if (flags & _NEGOTIATE_EXTENDED_SESSIONSECURITY) != 0 {
lm, nt, err := negotiateExtendedSessionSecurity(flags, bytes, challenge, auth.UserName, auth.Password, auth.Domain)
lm, nt, err := negotiateExtendedSessionSecurity(flags, bytes, challenge, auth.UserName, auth.Password, auth.Domain, auth.ChannelBinding)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions integratedauth/winsspi/winsspi.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ type Auth struct {
ctxt SecHandle
}

func (auth *Auth) SetChannelBinding(channelBinding []byte) {}

// getAuth returns an authentication handle Auth to provide authentication content
// to mssql.connect
func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error) {
Expand Down
18 changes: 14 additions & 4 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const (
NoTraceID = "notraceid"
GuidConversion = "guid conversion"
Timezone = "timezone"
DisableEPA = "disableepa"
)

type EncodeParameters struct {
Expand All @@ -103,6 +104,7 @@ func (e EncodeParameters) GetTimezone() *time.Location {
}

type Config struct {
DisableEPA bool
Port uint64
Host string
Instance string
Expand Down Expand Up @@ -569,6 +571,14 @@ func Parse(dsn string) (Config, error) {
p.Encoding.GuidConversion = false
}

disableEPA, ok := params[DisableEPA]
if ok {
p.DisableEPA, err = strconv.ParseBool(disableEPA)
if err != nil {
return p, fmt.Errorf("invalid disableEPA '%s': %s", disableEPA, err.Error())
}
}

return p, nil
}

Expand Down Expand Up @@ -711,11 +721,11 @@ func splitAdoConnectionStringParts(dsn string) []string {
var parts []string
var current strings.Builder
inQuotes := false

runes := []rune(dsn)
for i := 0; i < len(runes); i++ {
char := runes[i]

if char == '"' {
if inQuotes && i+1 < len(runes) && runes[i+1] == '"' {
// Double quote escape sequence - add both quotes to current part
Expand All @@ -735,12 +745,12 @@ func splitAdoConnectionStringParts(dsn string) []string {
current.WriteRune(char)
}
}

// Add the last part if it's not empty
if current.Len() > 0 {
parts = append(parts, current.String())
}

return parts
}

Expand Down
Loading