diff --git a/examples/channel_binding/tsql.go b/examples/channel_binding/tsql.go new file mode 100644 index 00000000..639c7c48 --- /dev/null +++ b/examples/channel_binding/tsql.go @@ -0,0 +1,225 @@ +package main + +import ( + "bufio" + "context" + "crypto/tls" + "database/sql" + "flag" + "fmt" + "io" + "log" + "os" + "time" + + "github.com/google/uuid" + mssqldb "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" + _ "github.com/microsoft/go-mssqldb/integratedauth/krb5" +) + +func main() { + var ( + userid = flag.String("U", "", "login_id") + password = flag.String("P", "", "password") + server = flag.String("S", "localhost", "server_name[\\instance_name]") + port = flag.Uint64("p", 1433, "server port") + keyLog = flag.String("K", "tlslog.log", "path to sslkeylog file") + database = flag.String("d", "", "db_name") + spn = flag.String("spn", "", "SPN") + auth = flag.String("a", "ntlm", "Authentication method: ntlm, krb5 or winsspi") + epa = flag.Bool("epa", true, "EPA enabled: true, false") + encrypt = flag.String("e", "required", "encrypt mode: required, disabled, strict, optional") + query = flag.String("q", "", "query to execute") + tlsMinVersion = flag.String("tlsmin", "1.1", "TLS minimum version: 1.0, 1.1, 1.2, 1.3") + tlsMaxVersion = flag.String("tlsmax", "1.3", "TLS maximum version: 1.0, 1.1, 1.2, 1.3") + ) + flag.Parse() + + keyLogFile, err := os.OpenFile(*keyLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + log.Fatal("failed to open keylog file:", err) + } + defer func() { + if cerr := keyLogFile.Close(); cerr != nil { + log.Printf("warning: failed to close keylog file: %v", cerr) + } + }() + + encryption, err := parseEncrypt(*encrypt) + if err != nil { + log.Fatal("failed to parse encrypt: ", err) + } + + tlsMinVersionNum := msdsn.TLSVersionFromString(*tlsMinVersion) + tlsMaxVersionNum := msdsn.TLSVersionFromString(*tlsMaxVersion) + cfg := msdsn.Config{ + User: *userid, + Database: *database, + Host: *server, + Port: *port, + Password: *password, + ChangePassword: "", + AppName: "go-mssqldb", + ServerSPN: *spn, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, // adjust for your case + ServerName: *server, + KeyLogWriter: keyLogFile, + DynamicRecordSizingDisabled: true, + MinVersion: tlsMinVersionNum, + MaxVersion: tlsMaxVersionNum, + }, + Encryption: encryption, + Parameters: map[string]string{ + "authenticator": *auth, + "krb5-credcachefile": os.Getenv("KRB5_CCNAME"), + "krb5-configfile": os.Getenv("KRB5_CONFIG"), + }, + ProtocolParameters: map[string]interface{}{}, + Protocols: []string{ + "tcp", + }, + Encoding: msdsn.EncodeParameters{ + Timezone: time.UTC, + GuidConversion: false, + }, + DialTimeout: time.Second * 5, + ConnTimeout: time.Second * 10, + KeepAlive: time.Second * 30, + EpaEnabled: *epa, + } + + activityid, uerr := uuid.NewRandom() + if uerr == nil { + cfg.ActivityID = activityid[:] + } + + workstation, err := os.Hostname() + if err == nil { + cfg.Workstation = workstation + } + + connector := mssqldb.NewConnectorConfig(cfg) + + _, err = connector.Connect(context.Background()) + if err != nil { + fmt.Println("connector.Connect: ", err.Error()) + return + } + + db := sql.OpenDB(connector) + defer db.Close() + + err = db.Ping() + if err != nil { + fmt.Println("Cannot connect: ", err.Error()) + return + } + + if *query != "" { + err = exec(db, *query) + if err != nil { + fmt.Println(err) + } + return + } + + r := bufio.NewReader(os.Stdin) + for { + _, err = os.Stdout.Write([]byte("> ")) + if err != nil { + fmt.Println(err) + return + } + cmd, err := r.ReadString('\n') + if err != nil { + if err == io.EOF { + fmt.Println() + return + } + fmt.Println(err) + return + } + err = exec(db, cmd) + if err != nil { + fmt.Println(err) + } + } +} +func exec(db *sql.DB, cmd string) error { + rows, err := db.Query(cmd) + if err != nil { + return err + } + defer rows.Close() + cols, err := rows.Columns() + if err != nil { + return err + } + if cols == nil { + return nil + } + vals := make([]interface{}, len(cols)) + for i := 0; i < len(cols); i++ { + vals[i] = new(interface{}) + if i != 0 { + fmt.Print("\t") + } + fmt.Print(cols[i]) + } + fmt.Println() + for rows.Next() { + err = rows.Scan(vals...) + if err != nil { + fmt.Println(err) + continue + } + for i := 0; i < len(vals); i++ { + if i != 0 { + fmt.Print("\t") + } + printValue(vals[i].(*interface{})) + } + fmt.Println() + + } + if rows.Err() != nil { + return rows.Err() + } + return nil +} + +func printValue(pval *interface{}) { + switch v := (*pval).(type) { + case nil: + fmt.Print("NULL") + case bool: + if v { + fmt.Print("1") + } else { + fmt.Print("0") + } + case []byte: + fmt.Print(string(v)) + case time.Time: + fmt.Print(v.Format("2006-01-02 15:04:05.999")) + default: + fmt.Print(v) + } +} + +func parseEncrypt(encrypt string) (msdsn.Encryption, error) { + switch encrypt { + case "required", "yes", "1", "t", "true", "": + return msdsn.EncryptionRequired, nil + case "disabled": + return msdsn.EncryptionDisabled, nil + case "strict": + return msdsn.EncryptionStrict, nil + case "optional", "no", "0", "f", "false": + return msdsn.EncryptionOff, nil + default: + return msdsn.EncryptionOff, fmt.Errorf("invalid encrypt '%s'", encrypt) + } +} \ No newline at end of file diff --git a/integratedauth/auth_test.go b/integratedauth/auth_test.go index a3d25e30..265628c4 100644 --- a/integratedauth/auth_test.go +++ b/integratedauth/auth_test.go @@ -17,6 +17,7 @@ type stubAuth struct { func (s *stubAuth) InitialBytes() ([]byte, error) { return nil, nil } func (s *stubAuth) NextBytes([]byte) ([]byte, error) { return nil, nil } func (s *stubAuth) Free() {} +func (s *stubAuth) SetChannelBinding(*ChannelBindings) {} func getAuth(config msdsn.Config) (IntegratedAuthenticator, error) { return &stubAuth{config.User}, nil diff --git a/integratedauth/channel_binding.go b/integratedauth/channel_binding.go new file mode 100644 index 00000000..41234b1d --- /dev/null +++ b/integratedauth/channel_binding.go @@ -0,0 +1,230 @@ +package integratedauth + +import ( + "crypto" + "crypto/md5" + "crypto/tls" + "crypto/x509" + "encoding/binary" + "fmt" +) + +const ( + // https://datatracker.ietf.org/doc/rfc9266/ + TLS_EXPORTER_PREFIX = "tls-exporter:" + TLS_EXPORTER_EKM_LABEL = "EXPORTER-Channel-Binding" + TLS_EXPORTER_EKM_LENGTH = 32 + // https://www.rfc-editor.org/rfc/rfc5801.html#section-5.2 + TLS_UNIQUE_PREFIX = "tls-unique:" + TLS_SERVER_END_POINT_PREFIX = "tls-server-end-point:" +) + +// gss_channel_bindings_struct: https://docs.oracle.com/cd/E19683-01/816-1331/overview-52/index.html +// gss_buffer_desc: https://docs.oracle.com/cd/E19683-01/816-1331/reference-21/index.html +type ChannelBindings struct { + InitiatorAddrType uint32 + InitiatorAddress []byte + AcceptorAddrType uint32 + AcceptorAddress []byte + ApplicationData []byte +} + +// SEC_CHANNEL_BINDINGS: https://learn.microsoft.com/en-us/windows/win32/api/sspi/ns-sspi-sec_channel_bindings +type SEC_CHANNEL_BINDINGS struct { + DwInitiatorAddrType uint32 + CbInitiatorLength uint32 + DwInitiatorOffset uint32 + DwAcceptorAddrType uint32 + CbAcceptorLength uint32 + DwAcceptorOffset uint32 + CbApplicationDataLength uint32 + DwApplicationDataOffset uint32 + Data []byte +} + +// ToBytes converts a ChannelBindings struct to a byte slice as it would be gss_channel_bindings_struct structure in GSSAPI. +// Returns: +// - a byte slice +func (cb *ChannelBindings) ToBytes() []byte { + binarylength := 4 + 4 + 4 + 4 + 4 + uint32(len(cb.InitiatorAddress)+len(cb.AcceptorAddress)+len(cb.ApplicationData)) + i := 0 + bytes := make([]byte, binarylength) + binary.LittleEndian.PutUint32(bytes[i:i+4], cb.InitiatorAddrType) + i += 4 + binary.LittleEndian.PutUint32(bytes[i:i+4], uint32(len(cb.InitiatorAddress))) + i += 4 + if len(cb.InitiatorAddress) > 0 { + copy(bytes[i:i+len(cb.InitiatorAddress)], cb.InitiatorAddress) + i += len(cb.InitiatorAddress) + } + binary.LittleEndian.PutUint32(bytes[i:i+4], cb.AcceptorAddrType) + i += 4 + binary.LittleEndian.PutUint32(bytes[i:i+4], uint32(len(cb.AcceptorAddress))) + i += 4 + if len(cb.AcceptorAddress) > 0 { + copy(bytes[i:i+len(cb.AcceptorAddress)], cb.AcceptorAddress) + i += len(cb.AcceptorAddress) + } + binary.LittleEndian.PutUint32(bytes[i:i+4], uint32(len(cb.ApplicationData))) + i += 4 + if len(cb.ApplicationData) > 0 { + copy(bytes[i:i+len(cb.ApplicationData)], cb.ApplicationData) + i += len(cb.ApplicationData) + } + // Print bytes in hexdump -C style for debugging + return bytes +} + +// Md5Hash calculates the MD5 hash of the ChannelBindings struct +// Returns: +// - a byte slice +func (cb *ChannelBindings) Md5Hash() []byte { + hash := md5.New() + hash.Write(cb.ToBytes()) + return hash.Sum(nil) +} + +// AsSSPI_SEC_CHANNEL_BINDINGS converts a ChannelBindings struct to a SEC_CHANNEL_BINDINGS struct +// Returns: +// - a SEC_CHANNEL_BINDINGS struct +func (cb *ChannelBindings) AsSSPI_SEC_CHANNEL_BINDINGS() *SEC_CHANNEL_BINDINGS { + initiatorOffset := uint32(32) + acceptorOffset := initiatorOffset + uint32(len(cb.InitiatorAddress)) + applicationDataOffset := acceptorOffset + uint32(len(cb.AcceptorAddress)) + c := &SEC_CHANNEL_BINDINGS{ + DwInitiatorAddrType: cb.InitiatorAddrType, + CbInitiatorLength: uint32(len(cb.InitiatorAddress)), + DwInitiatorOffset: initiatorOffset, + DwAcceptorAddrType: cb.AcceptorAddrType, + CbAcceptorLength: uint32(len(cb.AcceptorAddress)), + DwAcceptorOffset: acceptorOffset, + CbApplicationDataLength: uint32(len(cb.ApplicationData)), + DwApplicationDataOffset: applicationDataOffset, + } + data := make([]byte, c.CbInitiatorLength+c.CbAcceptorLength+c.CbApplicationDataLength) + var i uint32 = 0 + if c.CbInitiatorLength > 0 { + copy(data[i:i+c.CbInitiatorLength], cb.InitiatorAddress) + i += c.CbInitiatorLength + } + if c.CbAcceptorLength > 0 { + copy(data[i:i+c.CbAcceptorLength], cb.AcceptorAddress) + i += c.CbAcceptorLength + } + if c.CbApplicationDataLength > 0 { + copy(data[i:i+c.CbApplicationDataLength], cb.ApplicationData) + i += c.CbApplicationDataLength + } + c.Data = data + return c +} + +// ToBytes converts a SEC_CHANNEL_BINDINGS struct to a byte slice, that can be use in SSPI InitializeSecurityContext function. +// Returns: +// - a byte slice +func (cb *SEC_CHANNEL_BINDINGS) ToBytes() []byte { + bytes := make([]byte, 32+len(cb.Data)) + binary.LittleEndian.PutUint32(bytes[0:4], cb.DwInitiatorAddrType) + binary.LittleEndian.PutUint32(bytes[4:8], cb.CbInitiatorLength) + binary.LittleEndian.PutUint32(bytes[8:12], cb.DwInitiatorOffset) + binary.LittleEndian.PutUint32(bytes[12:16], cb.DwAcceptorAddrType) + binary.LittleEndian.PutUint32(bytes[16:20], cb.CbAcceptorLength) + binary.LittleEndian.PutUint32(bytes[20:24], cb.DwAcceptorOffset) + binary.LittleEndian.PutUint32(bytes[24:28], cb.CbApplicationDataLength) + binary.LittleEndian.PutUint32(bytes[28:32], cb.DwApplicationDataOffset) + copy(bytes[32:32+len(cb.Data)], cb.Data) + + return bytes +} + +// GenerateCBTFromTLSUnique generates a ChannelBindings struct from a TLS unique value +// Adds tls-unique: prefix to the TLS unique value. +// Parameters: +// - tlsUnique: the TLS unique value +// Returns: +// - a ChannelBindings struct +func GenerateCBTFromTLSUnique(tlsUnique []byte) (*ChannelBindings, error) { + if len(tlsUnique) == 0 { + return nil, fmt.Errorf("tlsUnique is empty") + } + return &ChannelBindings{ + InitiatorAddrType: 0, + InitiatorAddress: nil, + AcceptorAddrType: 0, + AcceptorAddress: nil, + ApplicationData: append([]byte(TLS_UNIQUE_PREFIX), tlsUnique...), + }, nil +} + +// GenerateCBTFromTLSConnState generates a ChannelBindings struct from a TLS connection state +// If the TLS version is TLS 1.3, it generates a ChannelBindings struct from the TLS exporter key. +// If the TLS version is not TLS 1.3, it generates a ChannelBindings struct from the TLS unique value. +// Parameters: +// - state: the TLS connection state +// Returns: +// - a ChannelBindings struct +func GenerateCBTFromTLSConnState(state tls.ConnectionState) (*ChannelBindings, error) { + switch state.Version { + case tls.VersionTLS13: + exporterKey, err := state.ExportKeyingMaterial(TLS_EXPORTER_EKM_LABEL, nil, TLS_EXPORTER_EKM_LENGTH) + if err != nil { + return nil, fmt.Errorf("error exporting keying material: %w", err) + } + return GenerateCBTFromTLSExporter(exporterKey) + default: + return GenerateCBTFromTLSUnique(state.TLSUnique) + } +} + +// GenerateCBTFromTLSExporter generates a ChannelBindings struct from a TLS exporter key +// Parameters: +// - exporterKey: the TLS exporter key +// Returns: +// - a ChannelBindings struct +func GenerateCBTFromTLSExporter(exporterKey []byte) (*ChannelBindings, error) { + if len(exporterKey) == 0 { + return nil, fmt.Errorf("exporterKey is empty") + } + + return &ChannelBindings{ + InitiatorAddrType: 0, + InitiatorAddress: nil, + AcceptorAddrType: 0, + AcceptorAddress: nil, + ApplicationData: append([]byte(TLS_EXPORTER_PREFIX), exporterKey...), + }, nil +} + +// GenerateCBTFromServerCert generates a ChannelBindings struct from a server certificate +// Calculates the hash of the server certificate as described in 4.2 section of RFC5056. +// Parameters: +// - cert: the server certificate +// Returns: +// - a ChannelBindings struct +func GenerateCBTFromServerCert(cert *x509.Certificate) *ChannelBindings { + if cert == nil { + return nil + } + var certHash []byte + var hashType crypto.Hash + switch cert.SignatureAlgorithm { + case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.SHA256WithRSAPSS: + hashType = crypto.SHA256 + case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS: + hashType = crypto.SHA384 + case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS: + hashType = crypto.SHA512 + default: + hashType = crypto.SHA256 + } + h := hashType.New() + _, _ = h.Write(cert.Raw) + certHash = h.Sum(nil) + return &ChannelBindings{ + InitiatorAddrType: 0, + InitiatorAddress: nil, + AcceptorAddrType: 0, + AcceptorAddress: nil, + ApplicationData: append([]byte(TLS_SERVER_END_POINT_PREFIX), certHash...), + } +} diff --git a/integratedauth/integratedauthenticator.go b/integratedauth/integratedauthenticator.go index ce8240d7..4f3f38ad 100644 --- a/integratedauth/integratedauthenticator.go +++ b/integratedauth/integratedauthenticator.go @@ -15,6 +15,7 @@ type IntegratedAuthenticator interface { InitialBytes() ([]byte, error) NextBytes([]byte) ([]byte, error) Free() + SetChannelBinding(*ChannelBindings) } // ProviderFunc is an adapter to convert a GetIntegratedAuthenticator func into a Provider diff --git a/integratedauth/krb5/krb5.go b/integratedauth/krb5/krb5.go index 2dbbd41c..3cf65b6e 100644 --- a/integratedauth/krb5/krb5.go +++ b/integratedauth/krb5/krb5.go @@ -252,6 +252,11 @@ type krbAuth struct { krb5Config *krb5Login spnegoClient *spnego.SPNEGO krb5Client *client.Client + channelBinding *integratedauth.ChannelBindings +} + +func (k *krbAuth) SetChannelBinding(channelBinding *integratedauth.ChannelBindings) { + k.channelBinding = channelBinding } func (k *krbAuth) InitialBytes() ([]byte, error) { diff --git a/integratedauth/ntlm/ntlm.go b/integratedauth/ntlm/ntlm.go index d95032f2..2c7a35ec 100644 --- a/integratedauth/ntlm/ntlm.go +++ b/integratedauth/ntlm/ntlm.go @@ -57,11 +57,20 @@ const _NEGOTIATE_FLAGS = _NEGOTIATE_UNICODE | _NEGOTIATE_ALWAYS_SIGN | _NEGOTIATE_EXTENDED_SESSIONSECURITY +const ( + AV_PAIR_MsvAvChannelBindings = 0x000A +) + type Auth struct { - Domain string - UserName string - Password string - Workstation string + Domain string + UserName string + Password string + Workstation string + ChannelBinding []byte +} + +func (auth *Auth) SetChannelBinding(channelBinding *integratedauth.ChannelBindings) { + auth.ChannelBinding = channelBinding.Md5Hash() } // getAuth returns an authentication handle Auth to provide authentication content @@ -72,10 +81,11 @@ func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error } domainUser := strings.SplitN(config.User, "\\", 2) return &Auth{ - Domain: domainUser[0], - UserName: domainUser[1], - Password: config.Password, - Workstation: config.Workstation, + Domain: domainUser[0], + UserName: domainUser[1], + Password: config.Password, + Workstation: config.Workstation, + ChannelBinding: []byte{}, }, nil } @@ -243,7 +253,7 @@ func getNTLMv2AndLMv2ResponsePayloads(userDomain, username, password string, cha return } -func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8]byte, username, password, userDom string) (lm, nt []byte, err error) { +func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8]byte, username, password, userDom string, channelBinding []byte) (lm, nt []byte, err error) { nonce := clientChallenge() // Official specification: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/b38c36ed-2804-4868-a9ff-8dd3182128e4 @@ -254,6 +264,18 @@ func negotiateExtendedSessionSecurity(flags uint32, message []byte, challenge [8 return lm, nt, err } + if len(channelBinding) > 0 { + av_pair_cb := make([]byte, 4) + // AvId + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e + binary.LittleEndian.PutUint16(av_pair_cb[0:2], AV_PAIR_MsvAvChannelBindings) + binary.LittleEndian.PutUint16(av_pair_cb[2:4], uint16(len(channelBinding))) + av_pair_cb = append(av_pair_cb, channelBinding...) + + targetInfoFields = append(targetInfoFields[:len(targetInfoFields)-4], av_pair_cb...) + targetInfoFields = append(targetInfoFields, 0, 0, 0, 0) + } + nt, lm = getNTLMv2AndLMv2ResponsePayloads(userDom, username, password, challenge, nonce, targetInfoFields, time.Now()) return lm, nt, nil @@ -376,7 +398,7 @@ func (auth *Auth) NextBytes(bytes []byte) ([]byte, error) { copy(challenge[:], bytes[24:32]) flags := binary.LittleEndian.Uint32(bytes[20:24]) if (flags & _NEGOTIATE_EXTENDED_SESSIONSECURITY) != 0 { - lm, nt, err := negotiateExtendedSessionSecurity(flags, bytes, challenge, auth.UserName, auth.Password, auth.Domain) + lm, nt, err := negotiateExtendedSessionSecurity(flags, bytes, challenge, auth.UserName, auth.Password, auth.Domain, auth.ChannelBinding) if err != nil { return nil, err } diff --git a/integratedauth/winsspi/winsspi.go b/integratedauth/winsspi/winsspi.go index 195d2288..0e49efac 100644 --- a/integratedauth/winsspi/winsspi.go +++ b/integratedauth/winsspi/winsspi.go @@ -26,6 +26,7 @@ func init() { const ( SEC_E_OK = 0 SECPKG_CRED_OUTBOUND = 2 + SECPKG_ATTR_UNIQUE_BINDINGS = 25 SEC_WINNT_AUTH_IDENTITY_UNICODE = 2 ISC_REQ_DELEGATE = 0x00000001 ISC_REQ_REPLAY_DETECT = 0x00000004 @@ -38,6 +39,7 @@ const ( SEC_I_COMPLETE_AND_CONTINUE = 0x00090314 SECBUFFER_VERSION = 0 SECBUFFER_TOKEN = 2 + SECBUFFER_CHANNEL_BINDINGS = 14 NTLMBUF_LEN = 12000 ) @@ -110,12 +112,22 @@ type SecBufferDesc struct { } type Auth struct { - Domain string - UserName string - Password string - Service string - cred SecHandle - ctxt SecHandle + Domain string + UserName string + Password string + Service string + cred SecHandle + ctxt SecHandle + channelBinding *integratedauth.SEC_CHANNEL_BINDINGS +} + +type SecPkgContext_Bindings struct { + BindingsLength uint64 + Bindings *byte +} + +func (auth *Auth) SetChannelBinding(channelBinding *integratedauth.ChannelBindings) { + auth.channelBinding = channelBinding.AsSSPI_SEC_CHANNEL_BINDINGS() } // getAuth returns an authentication handle Auth to provide authentication content @@ -133,6 +145,7 @@ func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error UserName: domainUser[1], Password: config.Password, Service: config.ServerSPN, + channelBinding: nil, }, nil } @@ -212,18 +225,33 @@ func (auth *Auth) InitialBytes() ([]byte, error) { func (auth *Auth) NextBytes(bytes []byte) ([]byte, error) { var in_buf, out_buf SecBuffer var in_desc, out_desc SecBufferDesc - - in_desc.ulVersion = SECBUFFER_VERSION - in_desc.cBuffers = 1 - in_desc.pBuffers = &in_buf + // Use fixed-size array instead of slice to ensure memory stability + var in_desc_buffers [2]SecBuffer + bufferCount := 0 out_desc.ulVersion = SECBUFFER_VERSION out_desc.cBuffers = 1 out_desc.pBuffers = &out_buf + // First buffer: input token in_buf.BufferType = SECBUFFER_TOKEN in_buf.pvBuffer = &bytes[0] in_buf.cbBuffer = uint32(len(bytes)) + in_desc_buffers[bufferCount] = in_buf + bufferCount++ + + // Second buffer: channel bindings (if present) + if auth.channelBinding != nil { + channelBindingBytes := auth.channelBinding.ToBytes() + in_desc_buffers[bufferCount].BufferType = SECBUFFER_CHANNEL_BINDINGS + in_desc_buffers[bufferCount].pvBuffer = &channelBindingBytes[0] + in_desc_buffers[bufferCount].cbBuffer = uint32(len(channelBindingBytes)) + bufferCount++ + } + + in_desc.ulVersion = SECBUFFER_VERSION + in_desc.cBuffers = uint32(bufferCount) + in_desc.pBuffers = &in_desc_buffers[0] outbuf := make([]byte, NTLMBUF_LEN) out_buf.BufferType = SECBUFFER_TOKEN diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 467e279d..f5570d16 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -22,6 +22,7 @@ type ( Encryption int Log uint64 BrowserMsg byte + EpaMode string ) const ( @@ -86,6 +87,7 @@ const ( NoTraceID = "notraceid" GuidConversion = "guid conversion" Timezone = "timezone" + EpaEnabled = "epa enabled" ) type EncodeParameters struct { @@ -159,6 +161,8 @@ type Config struct { NoTraceID bool // Parameters related to type encoding Encoding EncodeParameters + // EPA mode determines how the Channel Bindings are calculated. + EpaEnabled bool } func readDERFile(filename string) ([]byte, error) { @@ -569,6 +573,19 @@ func Parse(dsn string) (Config, error) { p.Encoding.GuidConversion = false } + epaString, ok := params[EpaEnabled] + if !ok { + epaString = os.Getenv("MSSQL_USE_EPA") + } + switch strings.ToLower(epaString) { + case "true", "1", "enabled", "yes", "y": + p.EpaEnabled = true + case "false", "0", "disabled", "no", "n": + p.EpaEnabled = false + default: + return p, fmt.Errorf("invalid epa enabled value '%s'", epaString) + } + return p, nil } @@ -711,11 +728,11 @@ func splitAdoConnectionStringParts(dsn string) []string { var parts []string var current strings.Builder inQuotes := false - + runes := []rune(dsn) for i := 0; i < len(runes); i++ { char := runes[i] - + if char == '"' { if inQuotes && i+1 < len(runes) && runes[i+1] == '"' { // Double quote escape sequence - add both quotes to current part @@ -735,12 +752,12 @@ func splitAdoConnectionStringParts(dsn string) []string { current.WriteRune(char) } } - + // Add the last part if it's not empty if current.Len() > 0 { parts = append(parts, current.String()) } - + return parts } diff --git a/tds.go b/tds.go index aaedaf71..d06ca4bf 100644 --- a/tds.go +++ b/tds.go @@ -1129,6 +1129,7 @@ func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls } func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { + var cbt *integratedauth.ChannelBindings isTransportEncrypted := false // if instance is specified use instance resolution service if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { @@ -1172,11 +1173,18 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) if p.Encryption == msdsn.EncryptionStrict { - outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0") + tlsConn, err := getTLSConn(toconn, p, "tds/8.0") if err != nil { return nil, err } isTransportEncrypted = true + outbuf.transport = tlsConn + if p.EpaEnabled { + cbt, err = integratedauth.GenerateCBTFromTLSConnState(tlsConn.ConnectionState()) + if err != nil { + logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Error while generating Channel Bindings from TLS connection state: %v", err)) + } + } } sess := newSession(outbuf, logger, p) @@ -1253,8 +1261,14 @@ initiate_connection: outbuf.transport = toconn } } - } + if p.EpaEnabled { + cbt, err = integratedauth.GenerateCBTFromTLSConnState(tlsConn.ConnectionState()) + if err != nil { + logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Error while generating Channel Bindings from TLS connection state: %v", err)) + } + } + } } auth, err := integratedauth.GetIntegratedAuthenticator(p) @@ -1268,6 +1282,9 @@ initiate_connection: if auth != nil { defer auth.Free() + if cbt != nil { + auth.SetChannelBinding(cbt) + } } login, err := prepareLogin(ctx, c, p, logger, auth, fedAuth, uint32(outbuf.PackageSize()))