Skip to content

Commit 9cd7417

Browse files
authored
Added experimental authentication methods (#1660)
1 parent d0fecab commit 9cd7417

File tree

10 files changed

+356
-4
lines changed

10 files changed

+356
-4
lines changed

common/client.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ type DatabricksClient struct {
4242
Username string `name:"username" env:"DATABRICKS_USERNAME" auth:"password"`
4343
Password string `name:"password" env:"DATABRICKS_PASSWORD" auth:"password,sensitive"`
4444

45+
ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth"`
46+
ClientSecret string `name:"client_secret" env:"DATABRICKS_CLIENT_SECRET" auth:"oauth,sensitive"`
47+
TokenEndpoint string `name:"token_endpoint" env:"DATABRICKS_TOKEN_ENDPOINT" auth:"oauth"`
48+
4549
// Databricks Account ID for Accounts API. This field is used in dependencies.
4650
AccountID string `name:"account_id" env:"DATABRICKS_ACCOUNT_ID"`
4751

@@ -251,6 +255,7 @@ func (c *DatabricksClient) Authenticate(ctx context.Context) error {
251255
providers := []auth{
252256
{c.configureWithPat, "pat"},
253257
{c.configureWithBasicAuth, "basic"},
258+
{c.configureWithOAuthM2M, "oauth-m2m"},
254259
{c.configureWithAzureClientSecret, "azure-client-secret"},
255260
{c.configureWithAzureManagedIdentity, "azure-msi"},
256261
{c.configureWithAzureCLI, "azure-cli"},
@@ -526,6 +531,8 @@ func (c *DatabricksClient) ClientForHost(ctx context.Context, url string) (*Data
526531
Username: c.Username,
527532
Password: c.Password,
528533
Token: c.Token,
534+
ClientID: c.ClientID,
535+
ClientSecret: c.ClientSecret,
529536
GoogleServiceAccount: c.GoogleServiceAccount,
530537
GoogleCredentials: c.GoogleCredentials,
531538
AzurermEnvironment: c.AzurermEnvironment,

common/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func TestDatabricksClient_FormatURL(t *testing.T) {
154154

155155
func TestClientAttributes(t *testing.T) {
156156
ca := ClientAttributes()
157-
assert.Len(t, ca, 22)
157+
assert.Len(t, ca, 25)
158158
}
159159

160160
func TestDatabricksClient_Authenticate(t *testing.T) {

common/gcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ func (c *DatabricksClient) configureWithGoogleForWorkspace(ctx context.Context)
122122
if err != nil {
123123
return nil, err
124124
}
125-
return newOidcAuthorizerForWorkspace(oidcSource), nil
125+
return newOidcAuthorizerWithJustBearer(oidcSource), nil
126126
}
127127

128-
func newOidcAuthorizerForWorkspace(oidcSource oauth2.TokenSource) func(r *http.Request) error {
128+
func newOidcAuthorizerWithJustBearer(oidcSource oauth2.TokenSource) func(r *http.Request) error {
129129
return func(r *http.Request) error {
130130
oidc, err := oidcSource.Token()
131131
if err != nil {

common/gcp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func TestNewOidcAuthorizerForWorkspace(t *testing.T) {
162162
AccessToken: "abc",
163163
TokenType: "Bearer",
164164
}
165-
auth := newOidcAuthorizerForWorkspace(
165+
auth := newOidcAuthorizerWithJustBearer(
166166
oauth2.StaticTokenSource(&token))
167167
request := httptest.NewRequest("GET", "http://localhost", nil)
168168
err := auth(request)

common/http.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,10 @@ func (c *DatabricksClient) recursiveMask(requestMap map[string]any) any {
385385
requestMap[k] = "**REDACTED**"
386386
continue
387387
}
388+
if k == "secret" {
389+
requestMap[k] = "**REDACTED**"
390+
continue
391+
}
388392
if k == "content" {
389393
requestMap[k] = "**REDACTED**"
390394
continue

common/m2m.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"log"
10+
"net/http"
11+
12+
"golang.org/x/oauth2"
13+
"golang.org/x/oauth2/clientcredentials"
14+
)
15+
16+
type oauthAuthorizationServer struct {
17+
AuthorizationEndpoint string `json:"authorization_endpoint"`
18+
TokenEndpoint string `json:"token_endpoint"`
19+
}
20+
21+
var errNotAvailable = errors.New("not available")
22+
23+
func (c *DatabricksClient) getOAuthEndpoints() (*oauthAuthorizationServer, error) {
24+
err := c.fixHost()
25+
if err != nil {
26+
return nil, fmt.Errorf("host: %w", err)
27+
}
28+
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", c.Host)
29+
oidcResponse, err := http.Get(oidc)
30+
if err != nil {
31+
return nil, errNotAvailable
32+
}
33+
if oidcResponse.Body == nil {
34+
return nil, fmt.Errorf("fetch .well-known: empty body")
35+
}
36+
defer oidcResponse.Body.Close()
37+
raw, err := io.ReadAll(oidcResponse.Body)
38+
if err != nil {
39+
return nil, fmt.Errorf("read .well-known: %w", err)
40+
}
41+
var oauthEndpoints oauthAuthorizationServer
42+
err = json.Unmarshal(raw, &oauthEndpoints)
43+
if err != nil {
44+
return nil, fmt.Errorf("parse .well-known: %w", err)
45+
}
46+
return &oauthEndpoints, nil
47+
}
48+
49+
func (c *DatabricksClient) configureWithOAuthM2M(
50+
ctx context.Context) (func(r *http.Request) error, error) {
51+
if !c.IsAws() || c.ClientID == "" || c.ClientSecret == "" || c.Host == "" {
52+
return nil, nil
53+
}
54+
// workaround for accounts endpoint not having yet a well-known OIDC alias
55+
if c.TokenEndpoint == "" {
56+
endpoints, err := c.getOAuthEndpoints()
57+
if err == errNotAvailable {
58+
return nil, nil
59+
}
60+
if err != nil {
61+
return nil, fmt.Errorf("databricks oauth: %w", err)
62+
}
63+
c.TokenEndpoint = endpoints.TokenEndpoint
64+
}
65+
log.Printf("[INFO] Generating Databricks OAuth token for Service Principal (%s)", c.ClientID)
66+
ts := (&clientcredentials.Config{
67+
ClientID: c.ClientID,
68+
ClientSecret: c.ClientSecret,
69+
AuthStyle: oauth2.AuthStyleInHeader,
70+
TokenURL: c.TokenEndpoint,
71+
Scopes: []string{"all-apis"},
72+
}).TokenSource(ctx)
73+
return newOidcAuthorizerWithJustBearer(ts), nil
74+
}

common/m2m_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestConfigureWithOAuthM2M(t *testing.T) {
14+
defer CleanupEnvironment()()
15+
cnt := []int{0}
16+
server := httptest.NewServer(http.HandlerFunc(
17+
func(rw http.ResponseWriter, req *http.Request) {
18+
if req.RequestURI ==
19+
"/oidc/.well-known/oauth-authorization-server" {
20+
_, err := rw.Write([]byte(
21+
`{"token_endpoint": "http://localhost/oauth/token"}`))
22+
assert.NoError(t, err)
23+
cnt[0]++
24+
return
25+
}
26+
assert.Fail(t, fmt.Sprintf("Received unexpected call: %s %s",
27+
req.Method, req.RequestURI))
28+
}))
29+
defer server.Close()
30+
31+
c := &DatabricksClient{
32+
Host: server.URL,
33+
ClientID: "abc",
34+
ClientSecret: "bcd",
35+
}
36+
_, err := c.configureWithOAuthM2M(context.Background())
37+
assert.NoError(t, err)
38+
assert.Equal(t, 1, cnt[0])
39+
}
40+
41+
func TestConfigureWithOAuthOIDCUnavailableSkips(t *testing.T) {
42+
defer CleanupEnvironment()()
43+
c := &DatabricksClient{
44+
Host: "http://localhost:22",
45+
ClientID: "abc",
46+
ClientSecret: "bcd",
47+
}
48+
_, err := c.configureWithOAuthM2M(context.Background())
49+
assert.NoError(t, err)
50+
}

provider/provider.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ func DatabricksProvider() *schema.Provider {
116116
"databricks_secret_acl": secrets.ResourceSecretACL(),
117117
"databricks_service_principal": scim.ResourceServicePrincipal(),
118118
"databricks_service_principal_role": aws.ResourceServicePrincipalRole(),
119+
"databricks_service_principal_secret": tokens.ResourceServicePrincipalSecret(),
119120
"databricks_sql_dashboard": sql.ResourceSqlDashboard(),
120121
"databricks_sql_endpoint": sql.ResourceSqlEndpoint(),
121122
"databricks_sql_global_config": sql.ResourceSqlGlobalConfig(),
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package tokens
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
"github.com/databricks/terraform-provider-databricks/common"
9+
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
10+
)
11+
12+
type ServicePrincipalSecret struct {
13+
ID string `json:"id,omitempty"`
14+
Secret string `json:"secret,omitempty" tf:"computed,sensitive"`
15+
Status string `json:"status,omitempty" tf:"computed"`
16+
}
17+
18+
type ListServicePrincipalSecrets struct {
19+
Secrets []ServicePrincipalSecret `json:"secrets"`
20+
}
21+
22+
// NewServicePrincipalSecretAPI creates ServicePrincipalSecretAPI instance from provider meta
23+
func NewServicePrincipalSecretAPI(ctx context.Context, m any) ServicePrincipalSecretAPI {
24+
return ServicePrincipalSecretAPI{m.(*common.DatabricksClient), ctx}
25+
}
26+
27+
// ServicePrincipalSecretAPI exposes the API to create client secrets
28+
type ServicePrincipalSecretAPI struct {
29+
client *common.DatabricksClient
30+
context context.Context
31+
}
32+
33+
func (a ServicePrincipalSecretAPI) createServicePrincipalSecret(spnID string) (secret *ServicePrincipalSecret, err error) {
34+
path := fmt.Sprintf("/accounts/%s/servicePrincipals/%s/credentials/secrets", a.client.AccountID, spnID)
35+
err = a.client.Post(a.context, path, map[string]any{}, &secret)
36+
return
37+
}
38+
39+
func (a ServicePrincipalSecretAPI) listServicePrincipalSecrets(spnID string) (secrets ListServicePrincipalSecrets, err error) {
40+
path := fmt.Sprintf("/accounts/%s/servicePrincipals/%s/credentials/secrets", a.client.AccountID, spnID)
41+
err = a.client.Get(a.context, path, nil, &secrets)
42+
return
43+
}
44+
45+
func (a ServicePrincipalSecretAPI) deleteServicePrincipalSecret(spnID, secretID string) error { // FIXME
46+
path := fmt.Sprintf("/accounts/%s/servicePrincipals/%s/credentials/secrets/%s", a.client.AccountID, spnID, secretID)
47+
return a.client.Delete(a.context, path, nil)
48+
}
49+
50+
func ResourceServicePrincipalSecret() *schema.Resource {
51+
spnSecretSchema := common.StructToSchema(ServicePrincipalSecret{},
52+
func(m map[string]*schema.Schema) map[string]*schema.Schema {
53+
m["id"].Computed = true
54+
m["service_principal_id"] = &schema.Schema{
55+
Type: schema.TypeString,
56+
ForceNew: true,
57+
Required: true,
58+
}
59+
return m
60+
})
61+
return common.Resource{
62+
Schema: spnSecretSchema,
63+
Create: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error {
64+
if c.AccountID == "" {
65+
return errors.New("must have `account_id` on provider")
66+
}
67+
idSeen := map[string]bool{}
68+
api := NewServicePrincipalSecretAPI(ctx, c)
69+
spnID := d.Get("service_principal_id").(string)
70+
secrets, err := api.listServicePrincipalSecrets(spnID)
71+
if err != nil {
72+
return err
73+
}
74+
for _, v := range secrets.Secrets {
75+
idSeen[v.ID] = true
76+
}
77+
secret, err := api.createServicePrincipalSecret(spnID)
78+
if err != nil {
79+
return err
80+
}
81+
secrets, err = api.listServicePrincipalSecrets(spnID)
82+
if err != nil {
83+
return err
84+
}
85+
// ugly hack because rpc does not return ID of created secret
86+
for _, v := range secrets.Secrets {
87+
if len(idSeen) > 0 && idSeen[v.ID] {
88+
continue
89+
}
90+
d.SetId(v.ID)
91+
}
92+
return d.Set("secret", secret.Secret)
93+
},
94+
Read: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error {
95+
if c.AccountID == "" {
96+
return errors.New("must have `account_id` on provider")
97+
}
98+
api := NewServicePrincipalSecretAPI(ctx, c)
99+
spnID := d.Get("service_principal_id").(string)
100+
secrets, err := api.listServicePrincipalSecrets(spnID)
101+
if err != nil {
102+
return err
103+
}
104+
for _, v := range secrets.Secrets {
105+
if v.ID != d.Id() {
106+
continue
107+
}
108+
return d.Set("status", v.Status)
109+
}
110+
return common.NotFound("client secret not found")
111+
},
112+
Delete: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error {
113+
if c.AccountID == "" {
114+
return errors.New("must have `account_id` on provider")
115+
}
116+
api := NewServicePrincipalSecretAPI(ctx, c)
117+
spnID := d.Get("service_principal_id").(string)
118+
return api.deleteServicePrincipalSecret(spnID, d.Id())
119+
},
120+
}.ToResource()
121+
}

0 commit comments

Comments
 (0)