Skip to content

Commit 7adb6d4

Browse files
authored
Fix Azure authentication for dev and staging workspaces (#1607)
* Fix Azure authentication for dev and staging workspaces * use env variable * Fix client attributes tests * Rename AzureDatabricksResourceId to AzureDatabricksLoginAppId * Simplify GetAzureDatabricksLoginAppId
1 parent 152867a commit 7adb6d4

File tree

5 files changed

+58
-12
lines changed

5 files changed

+58
-12
lines changed

common/azure_auth.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@ import (
1717
)
1818

1919
// List of management information
20-
const armDatabricksResourceID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
20+
const azureDatabricksProdLoginAppID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
21+
22+
func (aa *DatabricksClient) GetAzureDatabricksLoginAppId() string {
23+
if aa.AzureDatabricksLoginAppId != "" {
24+
return aa.AzureDatabricksLoginAppId
25+
}
26+
return azureDatabricksProdLoginAppID
27+
}
2128

22-
//
2329
func (aa *DatabricksClient) GetAzureJwtProperty(key string) (any, error) {
2430
if !aa.IsAzure() {
2531
return "", fmt.Errorf("can't get Azure JWT token in non-Azure environment")
@@ -146,6 +152,7 @@ func (aa *DatabricksClient) simpleAADRequestVisitor(
146152
if err != nil {
147153
return nil, fmt.Errorf("cannot get workspace: %w", err)
148154
}
155+
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
149156
platformAuthorizer, err := authorizerFactory(armDatabricksResourceID)
150157
if err != nil {
151158
return nil, fmt.Errorf("cannot authorize databricks: %w", err)
@@ -217,6 +224,7 @@ func (aa *DatabricksClient) getClientSecretAuthorizer(resource string) (autorest
217224
if aa.azureAuthorizer != nil {
218225
return aa.azureAuthorizer, nil
219226
}
227+
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
220228
if resource != armDatabricksResourceID {
221229
es := auth.EnvironmentSettings{
222230
Values: map[string]string{

common/azure_auth_test.go

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@ func TestGetClientSecretAuthorizer(t *testing.T) {
7070
env, err := aa.getAzureEnvironment()
7171
require.NoError(t, err)
7272
aa.AzureEnvironment = &env
73-
auth, err := aa.getClientSecretAuthorizer(armDatabricksResourceID)
73+
auth, err := aa.getClientSecretAuthorizer(azureDatabricksProdLoginAppID)
7474
require.Nil(t, auth)
7575
require.EqualError(t, err, "parameter 'clientID' cannot be empty")
7676

7777
aa.AzureTenantID = "a"
7878
aa.AzureClientID = "b"
7979
aa.AzureClientSecret = "c"
80-
auth, err = aa.getClientSecretAuthorizer(armDatabricksResourceID)
80+
auth, err = aa.getClientSecretAuthorizer(azureDatabricksProdLoginAppID)
8181
require.NotNil(t, auth)
8282
require.NoError(t, err)
8383

@@ -541,10 +541,46 @@ func TestSimpleAADRequestVisitor_FailPlatformAuth(t *testing.T) {
541541
},
542542
}).simpleAADRequestVisitor(context.Background(),
543543
func(resource string) (autorest.Authorizer, error) {
544-
if resource == armDatabricksResourceID {
544+
if resource == azureDatabricksProdLoginAppID {
545545
return nil, fmt.Errorf("🤨")
546546
}
547547
return autorest.NullAuthorizer{}, nil
548548
})
549549
assert.EqualError(t, err, "cannot authorize databricks: 🤨")
550550
}
551+
552+
func TestSimpleAADRequestVisitor_ProdLoginAppId(t *testing.T) {
553+
aa := DatabricksClient{
554+
Host: "abc.azuredatabricks.net",
555+
AzureEnvironment: &azure.Environment{
556+
ServiceManagementEndpoint: "x",
557+
},
558+
}
559+
_, err := aa.simpleAADRequestVisitor(context.Background(),
560+
func(resource string) (autorest.Authorizer, error) {
561+
if resource == "x" {
562+
return autorest.NullAuthorizer{}, nil
563+
}
564+
assert.Equal(t, azureDatabricksProdLoginAppID, resource)
565+
return autorest.NullAuthorizer{}, nil
566+
})
567+
assert.Nil(t, err)
568+
}
569+
570+
func TestSimpleAADRequestVisitor_LoginAppIdOverride(t *testing.T) {
571+
_, err := (&DatabricksClient{
572+
Host: "abc.azuredatabricks.net",
573+
AzureEnvironment: &azure.Environment{
574+
ServiceManagementEndpoint: "x",
575+
},
576+
AzureDatabricksLoginAppId: "y",
577+
}).simpleAADRequestVisitor(context.Background(),
578+
func(resource string) (autorest.Authorizer, error) {
579+
if resource == "x" {
580+
return autorest.NullAuthorizer{}, nil
581+
}
582+
assert.Equal(t, "y", resource)
583+
return autorest.NullAuthorizer{}, nil
584+
})
585+
assert.Nil(t, err)
586+
}

common/azure_cli_auth.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func (aa *DatabricksClient) configureWithAzureCLI(ctx context.Context) (func(*ht
102102
return nil, nil
103103
}
104104
// verify that Azure CLI is authenticated
105+
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
105106
_, err := cli.GetTokenFromCLI(armDatabricksResourceID)
106107
if err != nil {
107108
if strings.Contains(err.Error(), "executable file not found") {

common/client.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ type DatabricksClient struct {
5656
GoogleServiceAccount string `name:"google_service_account" env:"DATABRICKS_GOOGLE_SERVICE_ACCOUNT" auth:"google"`
5757
GoogleCredentials string `name:"google_credentials" env:"GOOGLE_CREDENTIALS" auth:"google,sensitive"`
5858

59-
AzureResourceID string `name:"azure_workspace_resource_id" env:"DATABRICKS_AZURE_RESOURCE_ID" auth:"azure"`
60-
AzureUseMSI bool `name:"azure_use_msi" env:"ARM_USE_MSI" auth:"azure"`
61-
AzureClientSecret string `name:"azure_client_secret" env:"ARM_CLIENT_SECRET" auth:"azure,sensitive"`
62-
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
63-
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`
64-
AzurermEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`
59+
AzureResourceID string `name:"azure_workspace_resource_id" env:"DATABRICKS_AZURE_RESOURCE_ID" auth:"azure"`
60+
AzureUseMSI bool `name:"azure_use_msi" env:"ARM_USE_MSI" auth:"azure"`
61+
AzureClientSecret string `name:"azure_client_secret" env:"ARM_CLIENT_SECRET" auth:"azure,sensitive"`
62+
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
63+
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`
64+
AzurermEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`
65+
AzureDatabricksLoginAppId string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"`
6566

6667
// When multiple auth attributes are available in the environment, use the auth type
6768
// specified by this argument. This argument also holds currently selected auth.

common/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func TestDatabricksClient_FormatURL(t *testing.T) {
154154

155155
func TestClientAttributes(t *testing.T) {
156156
ca := ClientAttributes()
157-
assert.Len(t, ca, 21)
157+
assert.Len(t, ca, 22)
158158
}
159159

160160
func TestDatabricksClient_Authenticate(t *testing.T) {

0 commit comments

Comments
 (0)