Skip to content

Commit 0a055b3

Browse files
authored
[Fix] Use correct domain for Azure Gov and China (#4274)
## Changes <!-- Summary of your changes that are easy to understand --> Resolves #4272 ## Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [x] `make test` run locally - [ ] relevant change in `docs/` folder - [ ] covered with integration tests in `internal/acceptance` - [x] relevant acceptance tests are passing - [ ] using Go SDK
1 parent a5eb85e commit 0a055b3

12 files changed

+97
-58
lines changed

storage/adls_gen1_mount.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type AzureADLSGen1Mount struct {
2020
}
2121

2222
// Source ...
23-
func (m AzureADLSGen1Mount) Source() string {
23+
func (m AzureADLSGen1Mount) Source(_ *common.DatabricksClient) string {
2424
return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory)
2525
}
2626

storage/adls_gen2_mount.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package storage
22

33
import (
44
"fmt"
5+
"strings"
56

67
"github.com/databricks/terraform-provider-databricks/common"
78
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
9+
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
810
)
911

1012
// AzureADLSGen2Mount describes the object for a azure datalake gen 2 storage mount
@@ -19,10 +21,23 @@ type AzureADLSGen2Mount struct {
1921
InitializeFileSystem bool `json:"initialize_file_system"`
2022
}
2123

24+
func getAzureDomain(client *common.DatabricksClient) string {
25+
domains := map[string]string{
26+
"PUBLIC": "core.windows.net",
27+
"USGOVERNMENT": "core.usgovcloudapi.net",
28+
"CHINA": "core.chinacloudapi.cn",
29+
}
30+
azureEnvironment := client.Config.Environment().AzureEnvironment.Name
31+
domain, ok := domains[strings.ToUpper(azureEnvironment)]
32+
if !ok {
33+
panic(fmt.Sprintf("Unknown Azure environment: '%s'", azureEnvironment))
34+
}
35+
return domain
36+
}
37+
2238
// Source returns ABFSS URI backing the mount
23-
func (m AzureADLSGen2Mount) Source() string {
24-
return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s",
25-
m.ContainerName, m.StorageAccountName, m.Directory)
39+
func (m AzureADLSGen2Mount) Source(client *common.DatabricksClient) string {
40+
return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
2641
}
2742

2843
func (m AzureADLSGen2Mount) Name() string {
@@ -106,5 +121,12 @@ func ResourceAzureAdlsGen2Mount() common.Resource {
106121
Required: true,
107122
ForceNew: true,
108123
},
124+
"environment": {
125+
Type: schema.TypeString,
126+
Optional: true,
127+
ForceNew: true,
128+
ValidateFunc: validation.StringInSlice([]string{"PUBLIC", "USGOVERNMENT", "CHINA"}, false),
129+
Default: "PUBLIC",
130+
},
109131
}))
110132
}

storage/adls_gen2_mount_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ import (
1010

1111
"github.com/databricks/terraform-provider-databricks/qa"
1212
"github.com/stretchr/testify/assert"
13-
"github.com/stretchr/testify/require"
1413
)
1514

1615
func TestResourceAdlsGen2Mount_Create(t *testing.T) {
17-
d, err := qa.ResourceFixture{
16+
qa.ResourceFixture{
1817
Fixtures: []qa.HTTPFixture{
1918
{
2019
Method: "GET",
@@ -51,8 +50,9 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) {
5150
"initialize_file_system": true,
5251
},
5352
Create: true,
54-
}.Apply(t)
55-
require.NoError(t, err)
56-
assert.Equal(t, "this_mount", d.Id())
57-
assert.Equal(t, "abfss://[email protected]", d.Get("source"))
53+
Azure: true,
54+
}.ApplyAndExpectData(t, map[string]any{
55+
"id": "this_mount",
56+
"source": "abfss://[email protected]",
57+
})
5858
}

storage/aws_s3_mount.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type AWSIamMount struct {
1616
}
1717

1818
// Source ...
19-
func (m AWSIamMount) Source() string {
19+
func (m AWSIamMount) Source(_ *common.DatabricksClient) string {
2020
return fmt.Sprintf("s3a://%s", m.S3BucketName)
2121
}
2222

storage/azure_blob_mount.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ type AzureBlobMount struct {
1919
}
2020

2121
// Source ...
22-
func (m AzureBlobMount) Source() string {
23-
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s",
24-
m.ContainerName, m.StorageAccountName, m.Directory)
22+
func (m AzureBlobMount) Source(client *common.DatabricksClient) string {
23+
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s",
24+
m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
2525
}
2626

2727
func (m AzureBlobMount) Name() string {

storage/azure_blob_mount_test.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
)
1515

1616
func TestResourceAzureBlobMountCreate(t *testing.T) {
17-
d, err := qa.ResourceFixture{
17+
qa.ResourceFixture{
1818
Fixtures: []qa.HTTPFixture{
1919
{
2020
Method: "GET",
@@ -50,11 +50,12 @@ func TestResourceAzureBlobMountCreate(t *testing.T) {
5050
"token_secret_key": "g",
5151
"token_secret_scope": "h",
5252
},
53+
Azure: true,
5354
Create: true,
54-
}.Apply(t)
55-
require.NoError(t, err)
56-
assert.Equal(t, "e", d.Id())
57-
assert.Equal(t, "wasbs://[email protected]/d", d.Get("source"))
55+
}.ApplyAndExpectData(t, map[string]any{
56+
"id": "e",
57+
"source": "wasbs://[email protected]/d",
58+
})
5859
}
5960

6061
func TestResourceAzureBlobMountCreate_Error(t *testing.T) {
@@ -86,6 +87,7 @@ func TestResourceAzureBlobMountCreate_Error(t *testing.T) {
8687
"token_secret_scope": "h",
8788
},
8889
Create: true,
90+
Azure: true,
8991
}.Apply(t)
9092
require.EqualError(t, err, "Some error")
9193
assert.Equal(t, "e", d.Id())
@@ -124,8 +126,9 @@ func TestResourceAzureBlobMountRead(t *testing.T) {
124126
"token_secret_key": "g",
125127
"token_secret_scope": "h",
126128
},
127-
ID: "e",
128-
Read: true,
129+
ID: "e",
130+
Read: true,
131+
Azure: true,
129132
}.Apply(t)
130133
require.NoError(t, err)
131134
assert.Equal(t, "e", d.Id())
@@ -165,6 +168,7 @@ func TestResourceAzureBlobMountRead_NotFound(t *testing.T) {
165168
ID: "e",
166169
Read: true,
167170
Removed: true,
171+
Azure: true,
168172
}.ApplyNoError(t)
169173
}
170174

@@ -198,8 +202,9 @@ func TestResourceAzureBlobMountRead_Error(t *testing.T) {
198202
"token_secret_key": "g",
199203
"token_secret_scope": "h",
200204
},
201-
ID: "e",
202-
Read: true,
205+
ID: "e",
206+
Azure: true,
207+
Read: true,
203208
}.Apply(t)
204209
require.EqualError(t, err, "Some error")
205210
assert.Equal(t, "e", d.Id())
@@ -239,6 +244,7 @@ func TestResourceAzureBlobMountDelete(t *testing.T) {
239244
},
240245
ID: "e",
241246
Delete: true,
247+
Azure: true,
242248
}.Apply(t)
243249
require.NoError(t, err)
244250
assert.Equal(t, "e", d.Id())

storage/generic_mounts.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ func (m GenericMount) getBlock() Mount {
4242
}
4343

4444
// Source returns URI backing the mount
45-
func (m GenericMount) Source() string {
45+
func (m GenericMount) Source(client *common.DatabricksClient) string {
4646
if block := m.getBlock(); block != nil {
47-
return block.Source()
47+
return block.Source(client)
4848
}
4949
return m.URI
5050
}
@@ -96,7 +96,7 @@ func parseStorageContainerId(rid string) (string, string, error) {
9696
return match[3], match[4], nil
9797
}
9898

99-
func getContainerDefaults(d *schema.ResourceData, allowed_schemas []string, suffix string) (string, string, error) {
99+
func getContainerDefaults(d *schema.ResourceData) (string, string, error) {
100100
rid := d.Get("resource_id").(string)
101101
if rid != "" {
102102
acc, cont, err := parseStorageContainerId(rid)
@@ -134,9 +134,8 @@ type AzureADLSGen2MountGeneric struct {
134134
}
135135

136136
// Source returns ABFSS URI backing the mount
137-
func (m *AzureADLSGen2MountGeneric) Source() string {
138-
return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s",
139-
m.ContainerName, m.StorageAccountName, m.Directory)
137+
func (m *AzureADLSGen2MountGeneric) Source(client *common.DatabricksClient) string {
138+
return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
140139
}
141140

142141
func (m *AzureADLSGen2MountGeneric) Name() string {
@@ -145,7 +144,7 @@ func (m *AzureADLSGen2MountGeneric) Name() string {
145144

146145
func (m *AzureADLSGen2MountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error {
147146
if m.ContainerName == "" || m.StorageAccountName == "" {
148-
acc, cont, err := getContainerDefaults(d, []string{"abfs", "abfss"}, "dfs.core.windows.net")
147+
acc, cont, err := getContainerDefaults(d)
149148
if err != nil {
150149
return err
151150
}
@@ -194,7 +193,7 @@ type AzureADLSGen1MountGeneric struct {
194193
}
195194

196195
// Source ...
197-
func (m *AzureADLSGen1MountGeneric) Source() string {
196+
func (m *AzureADLSGen1MountGeneric) Source(_ *common.DatabricksClient) string {
198197
return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory)
199198
}
200199

@@ -237,10 +236,9 @@ func (m *AzureADLSGen1MountGeneric) Config(client *common.DatabricksClient) map[
237236
aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint()
238237
return map[string]string{
239238
m.PrefixType + ".oauth2.access.token.provider.type": "ClientCredential",
240-
241-
m.PrefixType + ".oauth2.client.id": m.ClientID,
242-
m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey),
243-
m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID),
239+
m.PrefixType + ".oauth2.client.id": m.ClientID,
240+
m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey),
241+
m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID),
244242
}
245243
}
246244

@@ -257,9 +255,9 @@ type AzureBlobMountGeneric struct {
257255
}
258256

259257
// Source ...
260-
func (m *AzureBlobMountGeneric) Source() string {
261-
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s",
262-
m.ContainerName, m.StorageAccountName, m.Directory)
258+
func (m *AzureBlobMountGeneric) Source(client *common.DatabricksClient) string {
259+
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s",
260+
m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
263261
}
264262

265263
func (m *AzureBlobMountGeneric) Name() string {
@@ -268,7 +266,7 @@ func (m *AzureBlobMountGeneric) Name() string {
268266

269267
func (m *AzureBlobMountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error {
270268
if m.ContainerName == "" || m.StorageAccountName == "" {
271-
acc, cont, err := getContainerDefaults(d, []string{"wasb", "wasbs"}, "blob.core.windows.net")
269+
acc, cont, err := getContainerDefaults(d)
272270
if err != nil {
273271
return err
274272
}

storage/gs.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type GSMount struct {
1919
}
2020

2121
// Source ...
22-
func (m GSMount) Source() string {
22+
func (m GSMount) Source(_ *common.DatabricksClient) string {
2323
return fmt.Sprintf("gs://%s", m.BucketName)
2424
}
2525

storage/mounts.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020

2121
// Mount exposes generic url & extra config map options
2222
type Mount interface {
23-
Source() string
23+
Source(client *common.DatabricksClient) string
2424
Config(client *common.DatabricksClient) map[string]string
2525

2626
Name() string
@@ -96,7 +96,7 @@ func (mp MountPoint) Mount(mo Mount, client *common.DatabricksClient) (source st
9696
raise e
9797
mount_source = safe_mount("/mnt/%s", "%v", %s, "%s")
9898
dbutils.notebook.exit(mount_source)
99-
`, mp.Name, mo.Source(), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting]
99+
`, mp.Name, mo.Source(client), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting]
100100
result := mp.Exec.Execute(mp.ClusterID, "python", command)
101101
return result.Text(), result.Err()
102102
}
@@ -235,7 +235,7 @@ func mountCreate(tpl any, r common.Resource) func(context.Context, *schema.Resou
235235
if err != nil {
236236
return err
237237
}
238-
log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(), d.Id())
238+
log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(client), d.Id())
239239
source, err := mountPoint.Mount(mountConfig, client)
240240
if err != nil {
241241
return err

storage/mounts_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ func testMountFuncHelper(t *testing.T, mountFunc func(mp MountPoint, mount Mount
7474

7575
type mockMount struct{}
7676

77-
func (t mockMount) Source() string { return "fake-mount" }
78-
func (t mockMount) Name() string { return "fake-mount" }
77+
func (t mockMount) Source(_ *common.DatabricksClient) string { return "fake-mount" }
78+
func (t mockMount) Name() string { return "fake-mount" }
7979
func (t mockMount) Config(client *common.DatabricksClient) map[string]string {
8080
return map[string]string{"fake-key": "fake-value"}
8181
}
@@ -84,6 +84,14 @@ func (m mockMount) ValidateAndApplyDefaults(d *schema.ResourceData, client *comm
8484
}
8585

8686
func TestMountPoint_Mount(t *testing.T) {
87+
client := common.DatabricksClient{
88+
DatabricksClient: &client.DatabricksClient{
89+
Config: &config.Config{
90+
Host: ".",
91+
Token: ".",
92+
},
93+
},
94+
}
8795
mount := mockMount{}
8896
expectedMountSource := "fake-mount"
8997
expectedMountConfig := `{"fake-key":"fake-value"}`
@@ -108,14 +116,6 @@ func TestMountPoint_Mount(t *testing.T) {
108116
dbutils.notebook.exit(mount_source)
109117
`, mountName, expectedMountSource, expectedMountConfig)
110118
testMountFuncHelper(t, func(mp MountPoint, mount Mount) (s string, e error) {
111-
client := common.DatabricksClient{
112-
DatabricksClient: &client.DatabricksClient{
113-
Config: &config.Config{
114-
Host: ".",
115-
Token: ".",
116-
},
117-
},
118-
}
119119
return mp.Mount(mount, &client)
120120
}, mount, mountName, expectedCommand)
121121
}

0 commit comments

Comments
 (0)