Skip to content

Commit 95bfae3

Browse files
authored
Sanitize host field prior to auth flow (#1385)
1 parent 6291f7e commit 95bfae3

File tree

3 files changed

+81
-7
lines changed

3 files changed

+81
-7
lines changed

common/client.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"log"
99
"net/http"
10+
"net/url"
1011
"os"
1112
"reflect"
1213
"strings"
@@ -237,6 +238,11 @@ func (c *DatabricksClient) Authenticate(ctx context.Context) error {
237238
if c.authVisitor != nil {
238239
return nil
239240
}
241+
// Fix host prior to auth, because it may be used in the OIDC flow as "audience" field.
242+
// If necessary, this function adds a scheme and strips a trailing slash.
243+
if err := c.fixHost(); err != nil {
244+
return err
245+
}
240246
type auth struct {
241247
configure func(context.Context) (func(*http.Request) error, error)
242248
name string
@@ -326,12 +332,34 @@ func (c *DatabricksClient) niceAuthError(message string) error {
326332
return fmt.Errorf("%s%s. Please check %s for details", message, info, docUrl)
327333
}
328334

329-
func (c *DatabricksClient) fixHost() {
330-
if c.Host != "" && !(strings.HasPrefix(c.Host, "https://") || strings.HasPrefix(c.Host, "http://")) {
331-
// azurerm_databricks_workspace.*.workspace_url is giving URL without scheme
332-
// so that is why this line is here
333-
c.Host = "https://" + c.Host
335+
func (c *DatabricksClient) fixHost() error {
336+
// Nothing to fix if the host isn't set.
337+
if c.Host == "" {
338+
return nil
339+
}
340+
341+
u, err := url.Parse(c.Host)
342+
if err != nil {
343+
return err
344+
}
345+
346+
// If the host is empty, assume the scheme wasn't included.
347+
if u.Host == "" {
348+
u, err = url.Parse("https://" + c.Host)
349+
if err != nil {
350+
return err
351+
}
334352
}
353+
354+
// Create new instance to ensure other fields are initialized as empty.
355+
u = &url.URL{
356+
Scheme: u.Scheme,
357+
Host: u.Host,
358+
}
359+
360+
// Store sanitized version of c.Host.
361+
c.Host = u.String()
362+
return nil
335363
}
336364

337365
func (c *DatabricksClient) configureWithPat(ctx context.Context) (func(*http.Request) error, error) {

common/client_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,49 @@ func TestConfigAttributeSetNonsense(t *testing.T) {
297297
}).Set(&DatabricksClient{}, 1)
298298
assert.EqualError(t, err, "cannot set of unknown type Chan")
299299
}
300+
301+
func TestDatabricksClientFixHost(t *testing.T) {
302+
hostForInput := func(in string) (string, error) {
303+
client := &DatabricksClient{
304+
Host: in,
305+
}
306+
if err := client.fixHost(); err != nil {
307+
return "", err
308+
}
309+
return client.Host, nil
310+
}
311+
312+
{
313+
// Strip trailing slash.
314+
out, err := hostForInput("https://accounts.gcp.databricks.com/")
315+
assert.Nil(t, err)
316+
assert.Equal(t, out, "https://accounts.gcp.databricks.com")
317+
}
318+
319+
{
320+
// Keep port.
321+
out, err := hostForInput("https://accounts.gcp.databricks.com:443")
322+
assert.Nil(t, err)
323+
assert.Equal(t, out, "https://accounts.gcp.databricks.com:443")
324+
}
325+
326+
{
327+
// Default scheme.
328+
out, err := hostForInput("accounts.gcp.databricks.com")
329+
assert.Nil(t, err)
330+
assert.Equal(t, out, "https://accounts.gcp.databricks.com")
331+
}
332+
333+
{
334+
// Default scheme with port.
335+
out, err := hostForInput("accounts.gcp.databricks.com:443")
336+
assert.Nil(t, err)
337+
assert.Equal(t, out, "https://accounts.gcp.databricks.com:443")
338+
}
339+
340+
{
341+
// Return error.
342+
_, err := hostForInput("://@@@accounts.gcp.databricks.com/")
343+
assert.NotNil(t, err)
344+
}
345+
}

provider/provider_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func TestConfig_AzurePAT(t *testing.T) {
198198
host: "https://adb-xxx.y.azuredatabricks.net/",
199199
token: "y",
200200
assertAzure: true,
201-
assertHost: "https://adb-xxx.y.azuredatabricks.net/",
201+
assertHost: "https://adb-xxx.y.azuredatabricks.net",
202202
assertAuth: "pat",
203203
}.apply(t)
204204
}
@@ -244,7 +244,7 @@ func TestConfig_PatFromDatabricksCfg(t *testing.T) {
244244
env: map[string]string{
245245
"HOME": "../common/testdata",
246246
},
247-
assertHost: "https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/",
247+
assertHost: "https://dbc-XXXXXXXX-YYYY.cloud.databricks.com",
248248
assertAuth: "databricks-cli",
249249
}.apply(t)
250250
}

0 commit comments

Comments
 (0)