@@ -32,22 +32,26 @@ import (
3232)
3333
3434const (
35- driverName = "sqlserver"
35+ schema = "sqlserver"
36+ stdDriverName = "sqlserver"
37+ azDriverName = "azuresql"
3638
3739 errNotSupported = "%s not supported by MSSQL client"
40+ fedauth = "fedauth"
3841)
3942
4043type mssqlDB struct {
4144 dsn string
4245 endpoint string
4346 port string
47+ driver string
4448}
4549
4650// New returns a new mssql database client.
4751func New (creds map [string ][]byte , database string ) xsql.DB {
4852 endpoint := string (creds [xpv1 .ResourceCredentialsSecretEndpointKey ])
4953 port := string (creds [xpv1 .ResourceCredentialsSecretPortKey ])
50-
54+ driver := stdDriverName
5155 host := endpoint
5256 if port != "" {
5357 host = fmt .Sprintf ("%s:%s" , endpoint , port )
@@ -57,16 +61,36 @@ func New(creds map[string][]byte, database string) xsql.DB {
5761 if database != "" {
5862 query .Add ("database" , database )
5963 }
60- u := & url.URL {
61- Scheme : driverName ,
62- User : url .UserPassword (string (creds [xpv1 .ResourceCredentialsSecretUserKey ]), string (creds [xpv1 .ResourceCredentialsSecretPasswordKey ])),
63- Host : host ,
64- RawQuery : query .Encode (),
64+ var u * url.URL
65+ if val , ok := creds [fedauth ]; ok {
66+ authType := string (val )
67+ query .Add (fedauth , authType )
68+ if authType == "ActiveDirectoryServicePrincipal" || authType == "ActiveDirectoryApplication" || authType == "ActiveDirectoryPassword" {
69+ query .Add ("password" , string (creds [xpv1 .ResourceCredentialsSecretPasswordKey ]))
70+ }
71+ if val , ok := creds [xpv1 .ResourceCredentialsSecretUserKey ]; ok {
72+ query .Add ("user id" , string (val ))
73+ }
74+ u = & url.URL {
75+ Scheme : schema ,
76+ Host : host ,
77+ RawQuery : query .Encode (),
78+ }
79+ driver = azDriverName
80+ } else {
81+
82+ u = & url.URL {
83+ Scheme : schema ,
84+ User : url .UserPassword (string (creds [xpv1 .ResourceCredentialsSecretUserKey ]), string (creds [xpv1 .ResourceCredentialsSecretPasswordKey ])),
85+ Host : host ,
86+ RawQuery : query .Encode (),
87+ }
6588 }
6689 return mssqlDB {
6790 dsn : u .String (),
6891 endpoint : endpoint ,
6992 port : port ,
93+ driver : driver ,
7094 }
7195}
7296
@@ -77,7 +101,7 @@ func (c mssqlDB) ExecTx(_ context.Context, _ []xsql.Query) error {
77101
78102// Exec the supplied query.
79103func (c mssqlDB ) Exec (ctx context.Context , q xsql.Query ) error {
80- d , err := sql .Open (driverName , c .dsn )
104+ d , err := sql .Open (c . driver , c .dsn )
81105 if err != nil {
82106 return err
83107 }
@@ -89,7 +113,7 @@ func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error {
89113
90114// Query the supplied query.
91115func (c mssqlDB ) Query (ctx context.Context , q xsql.Query ) (* sql.Rows , error ) {
92- d , err := sql .Open (driverName , c .dsn )
116+ d , err := sql .Open (c . driver , c .dsn )
93117 if err != nil {
94118 return nil , err
95119 }
@@ -100,7 +124,7 @@ func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
100124
101125// Scan the results of the supplied query into the supplied destination.
102126func (c mssqlDB ) Scan (ctx context.Context , q xsql.Query , dest ... interface {}) error {
103- db , err := sql .Open (driverName , c .dsn )
127+ db , err := sql .Open (c . driver , c .dsn )
104128 if err != nil {
105129 return err
106130 }
0 commit comments