Skip to content

Commit 0f3a5d4

Browse files
committed
added the customer auth function to authenticate in provider.go during the configuration process, and altered the other resources to refer to the pointer of the client object rather than the value
1 parent 291b5d7 commit 0f3a5d4

38 files changed

+181
-160
lines changed

databricks/azure_auth.go

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ package databricks
33
import (
44
"encoding/json"
55
"fmt"
6-
"log"
7-
"net/http"
8-
urlParse "net/url"
9-
106
"github.com/Azure/go-autorest/autorest/adal"
117
"github.com/Azure/go-autorest/autorest/azure"
128
"github.com/databrickslabs/databricks-terraform/client/service"
9+
"log"
10+
"net/http"
11+
urlParse "net/url"
1312
)
1413

1514
// List of management information
@@ -78,27 +77,26 @@ func (a *AzureAuth) getWorkspaceID(config *service.DBApiClientConfig) error {
7877
log.Println("[DEBUG] Getting Workspace ID via management token.")
7978
// Escape all the ids
8079
url := fmt.Sprintf("https://management.azure.com/subscriptions/%s/resourceGroups/%s"+
81-
"/providers/Microsoft.Databricks/workspaces/%s?api-version=2018-04-01",
80+
"/providers/Microsoft.Databricks/workspaces/%s",
8281
urlParse.PathEscape(a.TokenPayload.SubscriptionID),
8382
urlParse.PathEscape(a.TokenPayload.ResourceGroup),
8483
urlParse.PathEscape(a.TokenPayload.WorkspaceName))
85-
payload := &WorkspaceRequest{
86-
Properties: &WsProps{ManagedResourceGroupID: "/subscriptions/" + a.TokenPayload.SubscriptionID + "/resourceGroups/" + a.TokenPayload.ManagedResourceGroup},
87-
Name: a.TokenPayload.WorkspaceName,
88-
Location: a.TokenPayload.AzureRegion,
89-
}
9084
headers := map[string]string{
9185
"Content-Type": "application/json",
9286
"cache-control": "no-cache",
9387
"Authorization": "Bearer " + a.ManagementToken,
9488
}
95-
89+
type apiVersion struct {
90+
ApiVersion string `url:"api-version"`
91+
}
92+
uriPayload := apiVersion{
93+
ApiVersion: "2018-04-01",
94+
}
9695
var responseMap map[string]interface{}
97-
resp, err := service.PerformQuery(config, http.MethodPut, url, "2.0", headers, true, true, payload, nil)
96+
resp, err := service.PerformQuery(config, http.MethodGet, url, "2.0", headers, false, true, uriPayload, nil)
9897
if err != nil {
9998
return err
10099
}
101-
102100
err = json.Unmarshal(resp, &responseMap)
103101
if err != nil {
104102
return err
@@ -170,53 +168,36 @@ func (a *AzureAuth) getWorkspaceAccessToken(config *service.DBApiClientConfig) e
170168
// 2. Get Workspace ID
171169
// 3. Get Azure Databricks Platform OAuth Token using Databricks resource id
172170
// 4. Get Azure Databricks Workspace Personal Access Token for the SP (60 min duration)
173-
func (a *AzureAuth) initWorkspaceAndGetClient(config *service.DBApiClientConfig) (service.DBApiClient, error) {
174-
var dbClient service.DBApiClient
171+
func (a *AzureAuth) initWorkspaceAndGetClient(config *service.DBApiClientConfig) error {
172+
//var dbClient service.DBApiClient
175173

176174
// Get management token
177175
err := a.getManagementToken(config)
178176
if err != nil {
179-
return dbClient, err
177+
return err
180178
}
181179

182180
// Get workspace access token
183181
err = a.getWorkspaceID(config)
184182
if err != nil {
185-
return dbClient, err
183+
return err
186184
}
187185

188186
// Get platform token
189187
err = a.getADBPlatformToken(config)
190188
if err != nil {
191-
return dbClient, err
189+
return err
192190
}
193191

194192
// Get workspace personal access token
195193
err = a.getWorkspaceAccessToken(config)
196194
if err != nil {
197-
return dbClient, err
195+
return err
198196
}
199197

200-
var newOption service.DBApiClientConfig
201-
202-
// TODO: Eventually change this to include new Databricks domain names. May have to add new vars and/or deprecate existing args.
203-
newOption.Host = "https://" + a.TokenPayload.AzureRegion + ".azuredatabricks.net"
204-
newOption.Token = a.AdbAccessToken
198+
//// TODO: Eventually change this to include new Databricks domain names. May have to add new vars and/or deprecate existing args.
199+
config.Host = "https://" + a.TokenPayload.AzureRegion + ".azuredatabricks.net"
200+
config.Token = a.AdbAccessToken
205201

206-
// Headers to use aad tokens, hidden till tokens support secrets, scopes and acls
207-
//newOption.DefaultHeaders = map[string]string{
208-
// //"Content-Type": "application/x-www-form-urlencoded",
209-
// "X-Databricks-Azure-Workspace-Resource-Id": a.AdbWorkspaceResourceID,
210-
// "X-Databricks-Azure-SP-Management-Token": a.ManagementToken,
211-
// "cache-control": "no-cache",
212-
//}
213-
dbClient.SetConfig(&newOption)
214-
215-
// Spin for a while while the workspace comes up and starts behaving.
216-
_, err = dbClient.Clusters().ListNodeTypes()
217-
if err != nil {
218-
return dbClient, err
219-
}
220-
221-
return dbClient, err
202+
return nil
222203
}

databricks/azure_auth_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ func TestAzureAuthCreateApiToken(t *testing.T) {
2222

2323
azureAuth := AzureAuth{
2424
TokenPayload: &TokenPayload{
25-
ManagedResourceGroup: os.Getenv("TEST_MANAGED_RESOURCE_GROUP"),
26-
AzureRegion: "centralus",
27-
WorkspaceName: os.Getenv("TEST_WORKSPACE_NAME"),
28-
ResourceGroup: os.Getenv("TEST_RESOURCE_GROUP"),
25+
ManagedResourceGroup: os.Getenv("DATABRICKS_AZURE_MANAGED_RESOURCE_GROUP"),
26+
AzureRegion: os.Getenv("AZURE_REGION"),
27+
WorkspaceName: os.Getenv("DATABRICKS_AZURE_WORKSPACE_NAME"),
28+
ResourceGroup: os.Getenv("DATABRICKS_AZURE_RESOURCE_GROUP"),
2929
},
3030
ManagementToken: "",
3131
AdbWorkspaceResourceID: "",
@@ -36,10 +36,11 @@ func TestAzureAuthCreateApiToken(t *testing.T) {
3636
azureAuth.TokenPayload.TenantID = os.Getenv("DATABRICKS_AZURE_TENANT_ID")
3737
azureAuth.TokenPayload.ClientID = os.Getenv("DATABRICKS_AZURE_CLIENT_ID")
3838
azureAuth.TokenPayload.ClientSecret = os.Getenv("DATABRICKS_AZURE_CLIENT_SECRET")
39-
option := GetIntegrationDBClientOptions()
40-
api, err := azureAuth.initWorkspaceAndGetClient(option)
39+
config := GetIntegrationDBClientOptions()
40+
err := azureAuth.initWorkspaceAndGetClient(config)
4141
assert.NoError(t, err, err)
42-
42+
api := service.DBApiClient{}
43+
api.SetConfig(config)
4344
instancePoolInfo, instancePoolErr := api.InstancePools().Create(model.InstancePool{
4445
InstancePoolName: "my_instance_pool",
4546
MinIdleInstances: 0,

databricks/data_source_databricks_dbfs_file.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func dataSourceDBFSFile() *schema.Resource {
3535
func dataSourceDBFSFileRead(d *schema.ResourceData, m interface{}) error {
3636
path := d.Get("path").(string)
3737
limitFileSize := d.Get("limit_file_size").(bool)
38-
client := m.(service.DBApiClient)
38+
client := m.(*service.DBApiClient)
3939

4040
fileInfo, err := client.DBFS().Status(path)
4141
if err != nil {

databricks/data_source_databricks_dbfs_file_paths.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func dataSourceDBFSFilePaths() *schema.Resource {
4343
func dataSourceDBFSFilePathsRead(d *schema.ResourceData, m interface{}) error {
4444
path := d.Get("path").(string)
4545
recursive := d.Get("recursive").(bool)
46-
client := m.(service.DBApiClient)
46+
client := m.(*service.DBApiClient)
4747

4848
paths, err := client.DBFS().List(path, recursive)
4949
if err != nil {

databricks/data_source_databricks_default_user_roles.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
func dataSourceDefaultUserRoles() *schema.Resource {
99
return &schema.Resource{
1010
Read: func(d *schema.ResourceData, m interface{}) error {
11-
client := m.(service.DBApiClient)
11+
client := m.(*service.DBApiClient)
1212

1313
defaultRolesUserName := d.Get("default_username").(string)
1414
metaUser, err := client.Users().GetOrCreateDefaultMetaUser(defaultRolesUserName, defaultRolesUserName, true)

databricks/data_source_databricks_notebook.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func dataSourceNotebook() *schema.Resource {
5050
func dataSourceNotebookRead(d *schema.ResourceData, m interface{}) error {
5151
path := d.Get("path").(string)
5252
format := d.Get("format").(string)
53-
client := m.(service.DBApiClient)
53+
client := m.(*service.DBApiClient)
5454

5555
notebookInfo, err := client.Notebooks().Read(path)
5656
if err != nil {

databricks/data_source_databricks_notebook_paths.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func dataSourceNotebookPathsRead(d *schema.ResourceData, m interface{}) error {
4747
path := d.Get("path").(string)
4848
recursive := d.Get("recursive").(bool)
4949

50-
client := m.(service.DBApiClient)
50+
client := m.(*service.DBApiClient)
5151

5252
notebookList, err := client.Notebooks().List(path, recursive)
5353
if err != nil {

databricks/data_source_databricks_zones.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
func dataSourceClusterZones() *schema.Resource {
99
return &schema.Resource{
1010
Read: func(d *schema.ResourceData, m interface{}) error {
11-
client := m.(service.DBApiClient)
11+
client := m.(*service.DBApiClient)
1212

1313
zonesInfo, err := client.Clusters().ListZones()
1414
if err != nil {

databricks/mounts.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import (
1111

1212
// Mount interface describes the functionality of any mount which is create, read and delete
1313
type Mount interface {
14-
Create(client service.DBApiClient, clusterID string) error
15-
Delete(client service.DBApiClient, clusterID string) error
16-
Read(client service.DBApiClient, clusterID string) (string, error)
14+
Create(client *service.DBApiClient, clusterID string) error
15+
Delete(client *service.DBApiClient, clusterID string) error
16+
Read(client *service.DBApiClient, clusterID string) (string, error)
1717
}
1818

1919
// AWSIamMount describes the object for a aws mount using iam role
@@ -29,7 +29,7 @@ func NewAWSIamMount(s3BucketName string, mountName string) *AWSIamMount {
2929
}
3030

3131
// Create creates an aws iam mount given a cluster ID
32-
func (m AWSIamMount) Create(client service.DBApiClient, clusterID string) error {
32+
func (m AWSIamMount) Create(client *service.DBApiClient, clusterID string) error {
3333
iamMountCommand := fmt.Sprintf(`
3434
dbutils.fs.mount("s3a://%s", "/mnt/%s")
3535
dbutils.fs.ls("/mnt/%s")
@@ -47,7 +47,7 @@ dbutils.notebook.exit("success")
4747
}
4848

4949
// Delete deletes an aws iam mount given a cluster ID
50-
func (m AWSIamMount) Delete(client service.DBApiClient, clusterID string) error {
50+
func (m AWSIamMount) Delete(client *service.DBApiClient, clusterID string) error {
5151
iamMountCommand := fmt.Sprintf(`
5252
dbutils.fs.unmount("/mnt/%s")
5353
dbutils.fs.refreshMounts()
@@ -65,7 +65,7 @@ dbutils.notebook.exit("success")
6565
}
6666

6767
// Read verifies an aws iam mount given a cluster ID
68-
func (m AWSIamMount) Read(client service.DBApiClient, clusterID string) (string, error) {
68+
func (m AWSIamMount) Read(client *service.DBApiClient, clusterID string) (string, error) {
6969
iamMountCommand := fmt.Sprintf(`
7070
dbutils.fs.ls("/mnt/%s")
7171
for mount in dbutils.fs.mounts():
@@ -108,7 +108,7 @@ func NewAzureBlobMount(containerName string, storageAccountName string, director
108108
}
109109

110110
// Create creates a azure blob storage mount given a cluster id
111-
func (m AzureBlobMount) Create(client service.DBApiClient, clusterID string) error {
111+
func (m AzureBlobMount) Create(client *service.DBApiClient, clusterID string) error {
112112
var confKey string
113113

114114
if m.AuthType == "SAS" {
@@ -139,7 +139,7 @@ dbutils.notebook.exit("success")
139139
}
140140

141141
// Delete deletes a azure blob storage mount given a cluster id
142-
func (m AzureBlobMount) Delete(client service.DBApiClient, clusterID string) error {
142+
func (m AzureBlobMount) Delete(client *service.DBApiClient, clusterID string) error {
143143
iamMountCommand := fmt.Sprintf(`
144144
dbutils.fs.unmount("/mnt/%s")
145145
dbutils.fs.refreshMounts()
@@ -157,7 +157,7 @@ dbutils.notebook.exit("success")
157157
}
158158

159159
// Read verifies a azure blob storage mount given a cluster id
160-
func (m AzureBlobMount) Read(client service.DBApiClient, clusterID string) (string, error) {
160+
func (m AzureBlobMount) Read(client *service.DBApiClient, clusterID string) (string, error) {
161161
iamMountCommand := fmt.Sprintf(`
162162
dbutils.fs.ls("/mnt/%s")
163163
for mount in dbutils.fs.mounts():
@@ -208,7 +208,7 @@ func NewAzureADLSGen1Mount(storageResource string, directory string, mountName s
208208
}
209209

210210
// Create creates a azure datalake gen 1 storage mount given a cluster id
211-
func (m AzureADLSGen1Mount) Create(client service.DBApiClient, clusterID string) error {
211+
func (m AzureADLSGen1Mount) Create(client *service.DBApiClient, clusterID string) error {
212212
iamMountCommand := fmt.Sprintf(`
213213
try:
214214
configs = {"%s.oauth2.access.token.provider.type": "ClientCredential",
@@ -237,7 +237,7 @@ dbutils.notebook.exit("success")
237237
}
238238

239239
// Delete deletes a azure datalake gen 1 storage mount given a cluster id
240-
func (m AzureADLSGen1Mount) Delete(client service.DBApiClient, clusterID string) error {
240+
func (m AzureADLSGen1Mount) Delete(client *service.DBApiClient, clusterID string) error {
241241
iamMountCommand := fmt.Sprintf(`
242242
dbutils.fs.unmount("/mnt/%s")
243243
dbutils.fs.refreshMounts()
@@ -255,7 +255,7 @@ dbutils.notebook.exit("success")
255255
}
256256

257257
// Read verifies the azure datalake gen 1 storage mount given a cluster id
258-
func (m AzureADLSGen1Mount) Read(client service.DBApiClient, clusterID string) (string, error) {
258+
func (m AzureADLSGen1Mount) Read(client *service.DBApiClient, clusterID string) (string, error) {
259259
iamMountCommand := fmt.Sprintf(`
260260
dbutils.fs.ls("/mnt/%s")
261261
for mount in dbutils.fs.mounts():
@@ -306,7 +306,7 @@ func NewAzureADLSGen2Mount(containerName string, storageAccountName string, dire
306306
}
307307

308308
// Create creates a azure datalake gen 2 storage mount
309-
func (m AzureADLSGen2Mount) Create(client service.DBApiClient, clusterID string) error {
309+
func (m AzureADLSGen2Mount) Create(client *service.DBApiClient, clusterID string) error {
310310
iamMountCommand := fmt.Sprintf(`
311311
try:
312312
configs = {"fs.azure.account.auth.type": "OAuth",
@@ -339,7 +339,7 @@ dbutils.notebook.exit("success")
339339
}
340340

341341
// Delete deletes a azure datalake gen 2 storage mount
342-
func (m AzureADLSGen2Mount) Delete(client service.DBApiClient, clusterID string) error {
342+
func (m AzureADLSGen2Mount) Delete(client *service.DBApiClient, clusterID string) error {
343343
iamMountCommand := fmt.Sprintf(`
344344
dbutils.fs.unmount("/mnt/%s")
345345
dbutils.fs.refreshMounts()
@@ -357,7 +357,7 @@ dbutils.notebook.exit("success")
357357
}
358358

359359
// Read verifies the azure datalake gen 2 storage mount
360-
func (m AzureADLSGen2Mount) Read(client service.DBApiClient, clusterID string) (string, error) {
360+
func (m AzureADLSGen2Mount) Read(client *service.DBApiClient, clusterID string) (string, error) {
361361
iamMountCommand := fmt.Sprintf(`
362362
dbutils.fs.ls("/mnt/%s")
363363
for mount in dbutils.fs.mounts():

databricks/provider.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ func providerConfigureAzureClient(d *schema.ResourceData, providerVersion string
187187
AdbAccessToken: "",
188188
AdbPlatformToken: "",
189189
}
190-
log.Println("Running Azure Auth")
191-
return azureAuthSetup.initWorkspaceAndGetClient(config)
190+
191+
// Setup the CustomAuthorizer Function to be called at API invoke rather than client invoke
192+
config.CustomAuthorizer = func(config *service.DBApiClientConfig) error {
193+
return azureAuthSetup.initWorkspaceAndGetClient(config)
194+
}
195+
var dbClient service.DBApiClient
196+
dbClient.SetConfig(config)
197+
return &dbClient, nil
192198
}
193199

194200
func providerConfigure(d *schema.ResourceData, providerVersion string) (interface{}, error) {
@@ -214,5 +220,5 @@ func providerConfigure(d *schema.ResourceData, providerVersion string) (interfac
214220
config.UserAgent = fmt.Sprintf("databricks-tf-provider-%s", providerVersion)
215221
var dbClient service.DBApiClient
216222
dbClient.SetConfig(&config)
217-
return dbClient, nil
223+
return &dbClient, nil
218224
}

0 commit comments

Comments
 (0)