diff --git a/azuread/configuration.go b/azuread/configuration.go index f52cee3f..89909788 100644 --- a/azuread/configuration.go +++ b/azuread/configuration.go @@ -1,223 +1,238 @@ -//go:build go1.18 -// +build go1.18 - -package azuread - -import ( - "context" - "errors" - "fmt" - "os" - "strings" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - mssql "github.com/microsoft/go-mssqldb" - "github.com/microsoft/go-mssqldb/msdsn" -) - -const ( - ActiveDirectoryDefault = "ActiveDirectoryDefault" - ActiveDirectoryIntegrated = "ActiveDirectoryIntegrated" - ActiveDirectoryPassword = "ActiveDirectoryPassword" - ActiveDirectoryInteractive = "ActiveDirectoryInteractive" - // ActiveDirectoryMSI is a synonym for ActiveDirectoryManagedIdentity - ActiveDirectoryMSI = "ActiveDirectoryMSI" - ActiveDirectoryManagedIdentity = "ActiveDirectoryManagedIdentity" - // ActiveDirectoryApplication is a synonym for ActiveDirectoryServicePrincipal - ActiveDirectoryApplication = "ActiveDirectoryApplication" - ActiveDirectoryServicePrincipal = "ActiveDirectoryServicePrincipal" - ActiveDirectoryServicePrincipalAccessToken = "ActiveDirectoryServicePrincipalAccessToken" - scopeDefaultSuffix = "/.default" -) - -type azureFedAuthConfig struct { - adalWorkflow byte - mssqlConfig msdsn.Config - // The detected federated authentication library - fedAuthLibrary int - fedAuthWorkflow string - // Service principal logins - clientID string - tenantID string - clientSecret string - certificatePath string - resourceID string - - // AD password/managed identity/interactive - user string - password string - applicationClientID string -} - -// parse returns a config based on an msdsn-style connection string -func parse(dsn string) (*azureFedAuthConfig, error) { - mssqlConfig, err := msdsn.Parse(dsn) - if err != nil { - return nil, err - } - config := &azureFedAuthConfig{ - fedAuthLibrary: mssql.FedAuthLibraryReserved, - mssqlConfig: mssqlConfig, - } - - err = config.validateParameters(mssqlConfig.Parameters) - if err != nil { - return nil, err - } - - return config, nil -} - -func (p *azureFedAuthConfig) validateParameters(params map[string]string) error { - - fedAuthWorkflow, _ := params["fedauth"] - if fedAuthWorkflow == "" { - return nil - } - - p.fedAuthLibrary = mssql.FedAuthLibraryADAL - - p.applicationClientID, _ = params["applicationclientid"] - - switch { - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryPassword): - if p.applicationClientID == "" { - return errors.New("applicationclientid parameter is required for " + ActiveDirectoryPassword) - } - p.adalWorkflow = mssql.FedAuthADALWorkflowPassword - p.user, _ = params["user id"] - p.password, _ = params["password"] - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryIntegrated): - // Active Directory Integrated authentication is not fully supported: - // you can only use this by also implementing an a token provider - // and supplying it via ActiveDirectoryTokenProvider in the Connection. - p.adalWorkflow = mssql.FedAuthADALWorkflowIntegrated - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryManagedIdentity) || strings.EqualFold(fedAuthWorkflow, ActiveDirectoryMSI): - // When using MSI, to request a specific client ID or user-assigned identity, - // provide the ID in the "user id" parameter - p.adalWorkflow = mssql.FedAuthADALWorkflowMSI - p.resourceID, _ = params["resource id"] - p.clientID, _ = splitTenantAndClientID(params["user id"]) - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryApplication) || strings.EqualFold(fedAuthWorkflow, ActiveDirectoryServicePrincipal): - p.adalWorkflow = mssql.FedAuthADALWorkflowPassword - // Split the clientID@tenantID format - // If no tenant is provided we'll use the one from the server - p.clientID, p.tenantID = splitTenantAndClientID(params["user id"]) - if p.clientID == "" { - return errors.New("Must provide 'client id[@tenant id]' as username parameter when using ActiveDirectoryApplication authentication") - } - - p.clientSecret, _ = params["password"] - - p.certificatePath, _ = params["clientcertpath"] - - if p.certificatePath == "" && p.clientSecret == "" { - return errors.New("Must provide 'password' parameter when using ActiveDirectoryApplication authentication without cert/key credentials") - } - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryDefault): - p.adalWorkflow = mssql.FedAuthADALWorkflowPassword - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryInteractive): - if p.applicationClientID == "" { - return errors.New("applicationclientid parameter is required for " + ActiveDirectoryInteractive) - } - p.adalWorkflow = mssql.FedAuthADALWorkflowPassword - // user is an optional login hint - p.user, _ = params["user id"] - // we don't really have a password but we need to use some value. - p.adalWorkflow = mssql.FedAuthADALWorkflowPassword - case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryServicePrincipalAccessToken): - p.fedAuthLibrary = mssql.FedAuthLibrarySecurityToken - p.adalWorkflow = mssql.FedAuthADALWorkflowNone - p.password, _ = params["password"] - - if p.password == "" { - return errors.New("Must provide 'password' parameter when using ActiveDirectoryApplicationAuthToken authentication") - } - default: - return fmt.Errorf("Invalid federated authentication type '%s': expected one of %+v", - fedAuthWorkflow, - []string{ActiveDirectoryApplication, ActiveDirectoryServicePrincipal, ActiveDirectoryDefault, ActiveDirectoryIntegrated, ActiveDirectoryInteractive, ActiveDirectoryManagedIdentity, ActiveDirectoryMSI, ActiveDirectoryPassword}) - } - p.fedAuthWorkflow = fedAuthWorkflow - return nil -} - -func splitTenantAndClientID(user string) (string, string) { - // Split the user name into client id and tenant id at the @ symbol - at := strings.IndexRune(user, '@') - if at < 1 || at >= (len(user)-1) { - return user, "" - } - - return user[0:at], user[at+1:] -} - -func splitAuthorityAndTenant(authorityUrl string) (string, string) { - separatorIndex := strings.LastIndex(authorityUrl, "/") - tenant := authorityUrl[separatorIndex+1:] - authority := authorityUrl[:separatorIndex] - return authority, tenant -} - -func (p *azureFedAuthConfig) provideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) { - var cred azcore.TokenCredential - var err error - authority, tenant := splitAuthorityAndTenant(stsURL) - // client secret connection strings may override the server tenant - if p.tenantID != "" { - tenant = p.tenantID - } - scope := stsURL - if !strings.HasSuffix(serverSPN, scopeDefaultSuffix) { - scope = strings.TrimRight(serverSPN, "/") + scopeDefaultSuffix - } - - switch p.fedAuthWorkflow { - case ActiveDirectoryServicePrincipal, ActiveDirectoryApplication: - switch { - case p.certificatePath != "": - certData, err := os.ReadFile(p.certificatePath) - if err != nil { - certs, key, err := azidentity.ParseCertificates(certData, []byte(p.clientSecret)) - if err != nil { - cred, err = azidentity.NewClientCertificateCredential(tenant, p.clientID, certs, key, nil) - } - } - default: - cred, err = azidentity.NewClientSecretCredential(tenant, p.clientID, p.clientSecret, nil) - } - case ActiveDirectoryServicePrincipalAccessToken: - return p.password, nil - case ActiveDirectoryPassword: - cred, err = azidentity.NewUsernamePasswordCredential(tenant, p.applicationClientID, p.user, p.password, nil) - case ActiveDirectoryMSI, ActiveDirectoryManagedIdentity: - if p.resourceID != "" { - cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(p.resourceID)}) - } else if p.clientID != "" { - cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(p.clientID)}) - } else { - cred, err = azidentity.NewManagedIdentityCredential(nil) - } - case ActiveDirectoryInteractive: - c := cloud.Configuration{ActiveDirectoryAuthorityHost: authority} - config := azcore.ClientOptions{Cloud: c} - cred, err = azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ClientOptions: config, ClientID: p.applicationClientID}) - - default: - // Integrated just uses Default until azidentity adds Windows-specific authentication - cred, err = azidentity.NewDefaultAzureCredential(nil) - } - - if err != nil { - return "", err - } - opts := policy.TokenRequestOptions{Scopes: []string{scope}} - tk, err := cred.GetToken(ctx, opts) - if err != nil { - return "", err - } - return tk.Token, err -} +//go:build go1.18 +// +build go1.18 + +package azuread + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" +) + +const ( + ActiveDirectoryDefault = "ActiveDirectoryDefault" + ActiveDirectoryIntegrated = "ActiveDirectoryIntegrated" + ActiveDirectoryPassword = "ActiveDirectoryPassword" + ActiveDirectoryInteractive = "ActiveDirectoryInteractive" + // ActiveDirectoryMSI is a synonym for ActiveDirectoryManagedIdentity + ActiveDirectoryMSI = "ActiveDirectoryMSI" + ActiveDirectoryManagedIdentity = "ActiveDirectoryManagedIdentity" + // ActiveDirectoryApplication is a synonym for ActiveDirectoryServicePrincipal + ActiveDirectoryApplication = "ActiveDirectoryApplication" + ActiveDirectoryServicePrincipal = "ActiveDirectoryServicePrincipal" + ActiveDirectoryServicePrincipalAccessToken = "ActiveDirectoryServicePrincipalAccessToken" + scopeDefaultSuffix = "/.default" +) + +type azureFedAuthConfig struct { + adalWorkflow byte + mssqlConfig msdsn.Config + // The detected federated authentication library + fedAuthLibrary int + fedAuthWorkflow string + // Service principal logins + clientID string + tenantID string + clientSecret string + certificatePath string + resourceID string + + // AD password/managed identity/interactive + user string + password string + applicationClientID string +} + +// parse returns a config based on an msdsn-style connection string +func parse(dsn string) (*azureFedAuthConfig, error) { + mssqlConfig, err := msdsn.Parse(dsn) + if err != nil { + return nil, err + } + config := &azureFedAuthConfig{ + fedAuthLibrary: mssql.FedAuthLibraryReserved, + mssqlConfig: mssqlConfig, + } + + err = config.validateParameters(mssqlConfig.Parameters) + if err != nil { + return nil, err + } + + return config, nil +} + +// newConfig returns a config based on an msdsn.Config and a map of parameters. +func newConfig(dsnConfig msdsn.Config, params map[string]string) (*azureFedAuthConfig, error) { + config := &azureFedAuthConfig{ + fedAuthLibrary: mssql.FedAuthLibraryReserved, + mssqlConfig: dsnConfig, + } + + err := config.validateParameters(params) + if err != nil { + return nil, err + } + + return config, nil +} + +func (p *azureFedAuthConfig) validateParameters(params map[string]string) error { + + fedAuthWorkflow, _ := params["fedauth"] + if fedAuthWorkflow == "" { + return nil + } + + p.fedAuthLibrary = mssql.FedAuthLibraryADAL + + p.applicationClientID, _ = params["applicationclientid"] + + switch { + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryPassword): + if p.applicationClientID == "" { + return errors.New("applicationclientid parameter is required for " + ActiveDirectoryPassword) + } + p.adalWorkflow = mssql.FedAuthADALWorkflowPassword + p.user, _ = params["user id"] + p.password, _ = params["password"] + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryIntegrated): + // Active Directory Integrated authentication is not fully supported: + // you can only use this by also implementing an a token provider + // and supplying it via ActiveDirectoryTokenProvider in the Connection. + p.adalWorkflow = mssql.FedAuthADALWorkflowIntegrated + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryManagedIdentity) || strings.EqualFold(fedAuthWorkflow, ActiveDirectoryMSI): + // When using MSI, to request a specific client ID or user-assigned identity, + // provide the ID in the "user id" parameter + p.adalWorkflow = mssql.FedAuthADALWorkflowMSI + p.resourceID, _ = params["resource id"] + p.clientID, _ = splitTenantAndClientID(params["user id"]) + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryApplication) || strings.EqualFold(fedAuthWorkflow, ActiveDirectoryServicePrincipal): + p.adalWorkflow = mssql.FedAuthADALWorkflowPassword + // Split the clientID@tenantID format + // If no tenant is provided we'll use the one from the server + p.clientID, p.tenantID = splitTenantAndClientID(params["user id"]) + if p.clientID == "" { + return errors.New("Must provide 'client id[@tenant id]' as username parameter when using ActiveDirectoryApplication authentication") + } + + p.clientSecret, _ = params["password"] + + p.certificatePath, _ = params["clientcertpath"] + + if p.certificatePath == "" && p.clientSecret == "" { + return errors.New("Must provide 'password' parameter when using ActiveDirectoryApplication authentication without cert/key credentials") + } + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryDefault): + p.adalWorkflow = mssql.FedAuthADALWorkflowPassword + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryInteractive): + if p.applicationClientID == "" { + return errors.New("applicationclientid parameter is required for " + ActiveDirectoryInteractive) + } + p.adalWorkflow = mssql.FedAuthADALWorkflowPassword + // user is an optional login hint + p.user, _ = params["user id"] + // we don't really have a password but we need to use some value. + p.adalWorkflow = mssql.FedAuthADALWorkflowPassword + case strings.EqualFold(fedAuthWorkflow, ActiveDirectoryServicePrincipalAccessToken): + p.fedAuthLibrary = mssql.FedAuthLibrarySecurityToken + p.adalWorkflow = mssql.FedAuthADALWorkflowNone + p.password, _ = params["password"] + + if p.password == "" { + return errors.New("Must provide 'password' parameter when using ActiveDirectoryApplicationAuthToken authentication") + } + default: + return fmt.Errorf("Invalid federated authentication type '%s': expected one of %+v", + fedAuthWorkflow, + []string{ActiveDirectoryApplication, ActiveDirectoryServicePrincipal, ActiveDirectoryDefault, ActiveDirectoryIntegrated, ActiveDirectoryInteractive, ActiveDirectoryManagedIdentity, ActiveDirectoryMSI, ActiveDirectoryPassword}) + } + p.fedAuthWorkflow = fedAuthWorkflow + return nil +} + +func splitTenantAndClientID(user string) (string, string) { + // Split the user name into client id and tenant id at the @ symbol + at := strings.IndexRune(user, '@') + if at < 1 || at >= (len(user)-1) { + return user, "" + } + + return user[0:at], user[at+1:] +} + +func splitAuthorityAndTenant(authorityUrl string) (string, string) { + separatorIndex := strings.LastIndex(authorityUrl, "/") + tenant := authorityUrl[separatorIndex+1:] + authority := authorityUrl[:separatorIndex] + return authority, tenant +} + +func (p *azureFedAuthConfig) provideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) { + var cred azcore.TokenCredential + var err error + authority, tenant := splitAuthorityAndTenant(stsURL) + // client secret connection strings may override the server tenant + if p.tenantID != "" { + tenant = p.tenantID + } + scope := stsURL + if !strings.HasSuffix(serverSPN, scopeDefaultSuffix) { + scope = strings.TrimRight(serverSPN, "/") + scopeDefaultSuffix + } + + switch p.fedAuthWorkflow { + case ActiveDirectoryServicePrincipal, ActiveDirectoryApplication: + switch { + case p.certificatePath != "": + certData, err := os.ReadFile(p.certificatePath) + if err != nil { + certs, key, err := azidentity.ParseCertificates(certData, []byte(p.clientSecret)) + if err != nil { + cred, err = azidentity.NewClientCertificateCredential(tenant, p.clientID, certs, key, nil) + } + } + default: + cred, err = azidentity.NewClientSecretCredential(tenant, p.clientID, p.clientSecret, nil) + } + case ActiveDirectoryServicePrincipalAccessToken: + return p.password, nil + case ActiveDirectoryPassword: + cred, err = azidentity.NewUsernamePasswordCredential(tenant, p.applicationClientID, p.user, p.password, nil) + case ActiveDirectoryMSI, ActiveDirectoryManagedIdentity: + if p.resourceID != "" { + cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(p.resourceID)}) + } else if p.clientID != "" { + cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(p.clientID)}) + } else { + cred, err = azidentity.NewManagedIdentityCredential(nil) + } + case ActiveDirectoryInteractive: + c := cloud.Configuration{ActiveDirectoryAuthorityHost: authority} + config := azcore.ClientOptions{Cloud: c} + cred, err = azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ClientOptions: config, ClientID: p.applicationClientID}) + + default: + // Integrated just uses Default until azidentity adds Windows-specific authentication + cred, err = azidentity.NewDefaultAzureCredential(nil) + } + + if err != nil { + return "", err + } + opts := policy.TokenRequestOptions{Scopes: []string{scope}} + tk, err := cred.GetToken(ctx, opts) + if err != nil { + return "", err + } + return tk.Token, err +} diff --git a/azuread/driver.go b/azuread/driver.go index ccd9d824..11ec6bcb 100644 --- a/azuread/driver.go +++ b/azuread/driver.go @@ -7,8 +7,8 @@ import ( "context" "database/sql" "database/sql/driver" - mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" ) // DriverName is the name used to register the driver @@ -43,6 +43,16 @@ func NewConnector(dsn string) (*mssql.Connector, error) { return newConnectorConfig(config) } +// NewConnectorFromConfig returns a new connector with the provided configuration and additional parameters +func NewConnectorFromConfig(dsnConfig msdsn.Config, params map[string]string) (*mssql.Connector, error) { + config, err := newConfig(dsnConfig, params) + if err != nil { + return nil, err + } + + return newConnectorConfig(config) +} + // newConnectorConfig creates a Connector from config. func newConnectorConfig(config *azureFedAuthConfig) (*mssql.Connector, error) { switch config.fedAuthLibrary { diff --git a/buf.go b/buf.go index 68663421..7ec331d2 100644 --- a/buf.go +++ b/buf.go @@ -28,6 +28,8 @@ var bufpool = sync.Pool{ }, } +type TDSBuffer = tdsBuffer + // tdsBuffer reads and writes TDS packets of data to the transport. // The write and read buffers are separate to make sending attn signals // possible without locks. Currently attn signals are only sent during @@ -59,6 +61,14 @@ type tdsBuffer struct { afterFirst func() } +// NewTdsBuffer returns an exported version of *tdsBuffer +func NewTdsBuffer(buff []byte, size int) *TDSBuffer { + return &tdsBuffer{ + rbuf: buff, + rsize: size, + } +} + func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { // pull an existing buf if one is available or get and add a new buf to the bufpool diff --git a/error.go b/error.go index e60288a6..311884e9 100644 --- a/error.go +++ b/error.go @@ -1,7 +1,9 @@ package mssql import ( + "bytes" "database/sql/driver" + "encoding/binary" "fmt" ) @@ -23,6 +25,46 @@ type Error struct { All []Error } +// Marshal marshals the error to the wire protocol token. +func (e *Error) Marshal() ([]byte, error) { + buf := bytes.NewBuffer([]byte{ + byte(tokenError), + }) + length := 2 + // length + 4 + // number + 1 + // state + 1 + // class + (2 + 2*len(e.Message)) + // message + (1 + 2*len(e.ServerName)) + // server name + (1 + 2*len(e.ProcName)) + // proc name + 4 // line no + if err := binary.Write(buf, binary.LittleEndian, uint16(length)); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.LittleEndian, e.Number); err != nil { + return nil, err + } + if err := buf.WriteByte(e.State); err != nil { + return nil, err + } + if err := buf.WriteByte(e.Class); err != nil { + return nil, err + } + if err := writeUsVarChar(buf, e.Message); err != nil { + return nil, err + } + if err := writeBVarChar(buf, e.ServerName); err != nil { + return nil, err + } + if err := writeBVarChar(buf, e.ProcName); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.LittleEndian, e.LineNo); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + func (e Error) Error() string { return "mssql: " + e.Message } diff --git a/mssql.go b/mssql.go index 28288a4f..6880f712 100644 --- a/mssql.go +++ b/mssql.go @@ -23,9 +23,9 @@ import ( // ReturnStatus may be used to return the return value from a proc. // -// var rs mssql.ReturnStatus -// _, err := db.Exec("theproc", &rs) -// log.Printf("return status = %d", rs) +// var rs mssql.ReturnStatus +// _, err := db.Exec("theproc", &rs) +// log.Printf("return status = %d", rs) type ReturnStatus int32 var driverInstance = &Driver{processQueryText: true} @@ -150,6 +150,12 @@ func NewConnectorConfig(config msdsn.Config) *Connector { } } +type auth interface { + InitialBytes() ([]byte, error) + NextBytes([]byte) ([]byte, error) + Free() +} + // Connector holds the parsed DSN and is ready to make a new connection // at any time. // @@ -169,6 +175,9 @@ type Connector struct { // callback that can provide a security token during ADAL login adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error) + // auth allows to provide a custom authenticator. + auth auth + // SessionInitSQL is executed after marking a given session to be reset. // When not present, the next query will still reset the session to the // database defaults. @@ -231,6 +240,16 @@ func (c *Conn) IsValid() bool { return c.connectionGood } +// GetUnderlyingConn returns underlying raw server connection. +func (c *Conn) GetUnderlyingConn() io.ReadWriteCloser { + return c.sess.buf.transport +} + +// GetLoginFlags returns tokens returned by server during login handshake. +func (c *Conn) GetLoginFlags() []Token { + return c.sess.loginFlags +} + // checkBadConn marks the connection as bad based on the characteristics // of the supplied error. Bad connections will be dropped from the connection // pool rather than reused. @@ -878,12 +897,13 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { // not a variable length type ok should return false. // If length is not limited other than system limits, it should return math.MaxInt64. // The following are examples of returned values for various types: -// TEXT (math.MaxInt64, true) -// varchar(10) (10, true) -// nvarchar(10) (10, true) -// decimal (0, false) -// int (0, false) -// bytea(30) (30, true) +// +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) func (r *Rows) ColumnTypeLength(index int) (int64, bool) { return makeGoLangTypeLength(r.cols[index].ti) } @@ -891,9 +911,10 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) { // It should return // the precision and scale for decimal types. If not applicable, ok should be false. // The following are examples of returned values for various types: -// decimal(38, 4) (38, 4, true) -// int (0, 0, false) -// decimal (math.MaxInt64, math.MaxInt64, true) +// +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { return makeGoLangTypePrecisionScale(r.cols[index].ti) } @@ -1320,12 +1341,13 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { // not a variable length type ok should return false. // If length is not limited other than system limits, it should return math.MaxInt64. // The following are examples of returned values for various types: -// TEXT (math.MaxInt64, true) -// varchar(10) (10, true) -// nvarchar(10) (10, true) -// decimal (0, false) -// int (0, false) -// bytea(30) (30, true) +// +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { return makeGoLangTypeLength(r.cols[index].ti) } @@ -1333,9 +1355,10 @@ func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { // It should return // the precision and scale for decimal types. If not applicable, ok should be false. // The following are examples of returned values for various types: -// decimal(38, 4) (38, 4, true) -// int (0, 0, false) -// decimal (math.MaxInt64, math.MaxInt64, true) +// +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) { return makeGoLangTypePrecisionScale(r.cols[index].ti) } diff --git a/tds.go b/tds.go index abbe063d..f090e9b5 100644 --- a/tds.go +++ b/tds.go @@ -143,6 +143,7 @@ type tdsSession struct { logger ContextLogger routedServer string routedPort uint16 + loginFlags []Token } const ( @@ -168,9 +169,23 @@ func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error { + w.BeginPacket(packetType, false) + if err := WritePreLoginFields(w, fields); err != nil { + return err + } + return w.FinishPacket() +} + +// Writer is an interface that combines Writer and ByteWriter. +type Writer interface { + io.Writer + io.ByteWriter +} + +// WritePreLoginFields writes provided Pre-Login packet fields into the writer. +func WritePreLoginFields(w Writer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packetType, false) offset := uint16(5*len(fields) + 1) keys := make(keySlice, 0, len(fields)) for k := range fields { @@ -210,7 +225,7 @@ func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) return errors.New("Write method didn't write the whole value") } } - return w.FinishPacket() + return nil } func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { @@ -1195,6 +1210,15 @@ initiate_connection: break } + // Save options returned by the server so callers implementing + // proxies can pass them back to the original client. + switch tok.(type) { + case envChangeStruct, loginAckStruct, doneStruct: + if token, ok := tok.(Token); ok { + sess.loginFlags = append(sess.loginFlags, token) + } + } + switch token := tok.(type) { case sspiMsg: sspi_msg, err := auth.NextBytes(token) diff --git a/token.go b/token.go index 76d4e025..90d32a96 100644 --- a/token.go +++ b/token.go @@ -1,6 +1,7 @@ package mssql import ( + "bytes" "context" "encoding/binary" "fmt" @@ -102,10 +103,20 @@ const ( // interface for all tokens type tokenStruct interface{} +// Token represents a token that can be marshaled to wire representation. +type Token interface { + Marshal() ([]byte, error) +} + type orderStruct struct { ColIds []uint16 } +// DoneToken returns a Done token. +func DoneToken() Token { + return doneStruct{} +} + type doneStruct struct { Status uint16 CurCmd uint16 @@ -113,6 +124,23 @@ type doneStruct struct { errors []Error } +// Marshal returns the token's wire protocol representation. +func (d doneStruct) Marshal() ([]byte, error) { + buf := bytes.NewBuffer([]byte{ + byte(tokenDone), + }) + if err := binary.Write(buf, binary.LittleEndian, d.Status); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.LittleEndian, d.CurCmd); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.LittleEndian, d.RowCount); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } @@ -131,17 +159,35 @@ func (d doneStruct) getError() Error { type doneInProcStruct doneStruct +type envChangeStruct struct { + bytes []byte +} + +// Marshal returns the token's wire protocol representation. +func (e envChangeStruct) Marshal() ([]byte, error) { + return e.bytes, nil +} + // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx -func processEnvChg(ctx context.Context, sess *tdsSession) { +func processEnvChg(ctx context.Context, sess *tdsSession) envChangeStruct { + buf := bytes.NewBuffer([]byte{ + byte(tokenEnvChange), + }) size := sess.buf.uint16() - r := &io.LimitedReader{R: sess.buf, N: int64(size)} + if err := binary.Write(buf, binary.LittleEndian, size); err != nil { + badStreamPanic(err) + } + // Duplicate the token stream in the buffer. + r := io.TeeReader(&io.LimitedReader{R: sess.buf, N: int64(size)}, buf) for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { - return + return envChangeStruct{ + bytes: buf.Bytes(), + } } if err != nil { badStreamPanic(err) @@ -152,8 +198,7 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { if err != nil { badStreamPanic(err) } - _, err = readBVarChar(r) - if err != nil { + if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypLanguage: @@ -181,8 +226,7 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { if err != nil { badStreamPanic(err) } - _, err = readBVarChar(r) - if err != nil { + if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) @@ -217,7 +261,6 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { if err != nil { badStreamPanic(err) } - // SQL Collation data should contain 5 bytes in length if collationSize != 5 { badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) @@ -229,16 +272,14 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { if err != nil { badStreamPanic(err) } - // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { badStreamPanic(err) } - // old value, should be 0 - if _, err = readBVarChar(r); err != nil { + if _, err := readBVarChar(r); err != nil { badStreamPanic(err) } case envTypBeginTran: @@ -388,7 +429,9 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("WARN: Unknown ENVCHANGE record detected with type id = %d", envtype)) } - return + return envChangeStruct{ + bytes: buf.Bytes(), + } } } } @@ -797,7 +840,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS parseNbcRow(sess.buf, columns, row) ch <- row case tokenEnvChange: - processEnvChg(ctx, sess) + ch <- processEnvChg(ctx, sess) case tokenError: err := parseError72(sess.buf) if sess.logFlags&logDebug != 0 { diff --git a/types.go b/types.go index 3b4760e3..00b40caf 100644 --- a/types.go +++ b/types.go @@ -85,6 +85,8 @@ const ( fDefault = 0x200 ) +type TypeInfo = typeInfo + // TYPE_INFO rule // http://msdn.microsoft.com/en-us/library/dd358284.aspx type typeInfo struct { @@ -119,6 +121,11 @@ type xmlInfo struct { XmlSchemaCollection string } +// ReadTypeInfo returns an exported version of typeInfo +func ReadTypeInfo(r *TDSBuffer) TypeInfo { + return readTypeInfo(r) +} + func readTypeInfo(r *tdsBuffer) (res typeInfo) { res.TypeId = r.byte() switch res.TypeId { diff --git a/ucs22str.go b/ucs22str.go index 9c9c66eb..387e7b64 100644 --- a/ucs22str.go +++ b/ucs22str.go @@ -149,3 +149,8 @@ func ucs22str(s []byte) (string, error) { // After this point both s and uint16slice can be garbage collected. return string(utf16.Decode(uint16slice)), nil } + +// ParseUCS2String returns string from its UCS-2 encoded representation. +func ParseUCS2String(s []byte) (string, error) { + return ucs22str(s) +}