Skip to content

Commit 2d6a271

Browse files
authored
Merge pull request #163 from sadedil/master
Extract database name from query instead of path for mssql
2 parents 19a20bb + ade9e44 commit 2d6a271

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

job.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,33 @@ func (j *Job) updateConnections() {
356356
if u.User != nil {
357357
user = u.User.Username()
358358
}
359+
360+
// For SQL Server connections,
361+
// url.path is reserved for sql server instance not the database name
362+
// database name can be specified in multiple ways:
363+
// 1. Query parameter: database=dbname
364+
// 2. Query parameter: initial catalog=dbname
365+
database := ""
366+
if strings.HasPrefix(conn, "sqlserver://") {
367+
// Check for 'database' parameter first
368+
if dbParam := getQueryStringCaseInsensitive(u.Query(), "database"); dbParam != "" {
369+
database = dbParam
370+
} else if catalogParam := getQueryStringCaseInsensitive(u.Query(), "initial catalog"); catalogParam != "" {
371+
// 'initial catalog' is an alternative to 'database' parameter
372+
database = catalogParam
373+
}
374+
} else {
375+
database = strings.TrimPrefix(u.Path, "/")
376+
}
377+
359378
// we expose some of the connection variables as labels, so we need to
360379
// remember them
361380
newConn := &connection{
362381
conn: nil,
363382
url: conn,
364383
driver: u.Scheme,
365384
host: u.Host,
366-
database: strings.TrimPrefix(u.Path, "/"),
385+
database: database,
367386
user: user,
368387
}
369388
if newConn.driver == "athena" {
@@ -387,11 +406,11 @@ func (j *Job) updateConnections() {
387406
privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file"))
388407

389408
cfg := &gosnowflake.Config{
390-
Account: u.Host,
391-
User: u.User.Username(),
392-
Role: queryParams.Get("role"),
409+
Account: u.Host,
410+
User: u.User.Username(),
411+
Role: queryParams.Get("role"),
393412
Database: queryParams.Get("database"),
394-
Schema: queryParams.Get("schema"),
413+
Schema: queryParams.Get("schema"),
395414
}
396415

397416
if privateKeyPath != "" {
@@ -683,3 +702,13 @@ func (c *connection) connect(job *Job) error {
683702
c.conn = conn
684703
return nil
685704
}
705+
706+
func getQueryStringCaseInsensitive(values url.Values, key string) string {
707+
key = strings.ToLower(key)
708+
for k, v := range values {
709+
if strings.ToLower(k) == key && len(v) > 0 {
710+
return v[0]
711+
}
712+
}
713+
return ""
714+
}

0 commit comments

Comments
 (0)