Skip to content

Commit 07af6d4

Browse files
authored
feat(go/databricks): OAuth M2M and U2M (#88)
* OAuth M2M and U2M * fix typo
1 parent 4079e17 commit 07af6d4

File tree

2 files changed

+88
-18
lines changed

2 files changed

+88
-18
lines changed

go/adbc/driver/databricks/database.go

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import (
3434
"github.com/apache/arrow-adbc/go/adbc"
3535
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
3636
dbsql "github.com/databricks/databricks-sql-go"
37+
"github.com/databricks/databricks-sql-go/auth/oauth/m2m"
38+
"github.com/databricks/databricks-sql-go/auth/oauth/u2m"
3739
)
3840

3941
const (
@@ -52,7 +54,6 @@ type databaseImpl struct {
5254
// Connection parameters
5355
serverHostname string
5456
httpPath string
55-
accessToken string
5657
port int
5758
catalog string
5859
schema string
@@ -69,10 +70,15 @@ type databaseImpl struct {
6970
sslCertPool *x509.CertPool
7071
sslInsecure bool
7172

72-
// OAuth options (for future expansion)
73+
// Auth
74+
authType string
75+
76+
accessToken string
77+
7378
oauthClientID string
7479
oauthClientSecret string
7580
oauthRefreshToken string
81+
oauthU2MTimeout time.Duration
7682
}
7783

7884
func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) {
@@ -90,20 +96,60 @@ func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) {
9096
}
9197
}
9298

93-
// FIXME: Support other auth methods
94-
if d.accessToken == "" {
95-
return nil, adbc.Error{
96-
Code: adbc.StatusInvalidArgument,
97-
Msg: "access token is required",
98-
}
99-
}
100-
10199
opts := []dbsql.ConnOption{
102-
dbsql.WithAccessToken(d.accessToken),
103100
dbsql.WithServerHostname(d.serverHostname),
104101
dbsql.WithHTTPPath(d.httpPath),
105102
}
106103

104+
// Handle Auth configurations and validate based on user selected auth type
105+
switch d.authType {
106+
case OptionValueAuthTypePAT:
107+
if d.accessToken == "" {
108+
return nil, adbc.Error{
109+
Code: adbc.StatusInvalidArgument,
110+
Msg: fmt.Sprintf("access token is required when using auth type '%s'. Set this via '%s'.", OptionValueAuthTypePAT, OptionAccessToken),
111+
}
112+
}
113+
opts = append(opts, dbsql.WithAccessToken(d.accessToken))
114+
case OptionValueAuthTypeOAuthM2M:
115+
if d.oauthClientID == "" {
116+
return nil, adbc.Error{
117+
Code: adbc.StatusInvalidArgument,
118+
Msg: fmt.Sprintf("client ID is required when using auth type '%s'. Set this via '%s'.", OptionValueAuthTypeOAuthM2M, OptionOAuthClientID),
119+
}
120+
}
121+
if d.oauthClientSecret == "" {
122+
return nil, adbc.Error{
123+
Code: adbc.StatusInvalidArgument,
124+
Msg: fmt.Sprintf("client secret is required when using auth type '%s'. Set this via '%s'.", OptionValueAuthTypeOAuthM2M, OptionOAuthClientSecret),
125+
}
126+
}
127+
authenticator := m2m.NewAuthenticator(
128+
d.oauthClientID,
129+
d.oauthClientSecret,
130+
d.serverHostname,
131+
)
132+
opts = append(opts, dbsql.WithAuthenticator(authenticator))
133+
case OptionValueAuthTypeExternalBrowser:
134+
timeout := d.oauthU2MTimeout
135+
if timeout == 0 {
136+
timeout = DefaultExternalBrowserTimeout
137+
}
138+
authenticator, err := u2m.NewAuthenticator(d.serverHostname, timeout)
139+
if err != nil {
140+
return nil, adbc.Error{
141+
Code: adbc.StatusInvalidState,
142+
Msg: fmt.Sprintf("failed to initialize authenticator: %v", err),
143+
}
144+
}
145+
opts = append(opts, dbsql.WithAuthenticator(authenticator))
146+
default:
147+
return nil, adbc.Error{
148+
Code: adbc.StatusInvalidArgument,
149+
Msg: fmt.Sprintf("missing required option: '%s'", OptionAuthType),
150+
}
151+
}
152+
107153
// Validate and set custom port
108154
// Defaults to 443
109155
if d.port != 0 {
@@ -296,6 +342,8 @@ func (d *databaseImpl) SetOption(key, value string) error {
296342
// We need to re-initialize the db/connection pool if options change
297343
d.needsRefresh = true
298344
switch key {
345+
case OptionAuthType:
346+
d.authType = value
299347
case OptionServerHostname:
300348
d.serverHostname = value
301349
case OptionHTTPPath:
@@ -412,6 +460,17 @@ func (d *databaseImpl) SetOption(key, value string) error {
412460
d.oauthClientSecret = value
413461
case OptionOAuthRefreshToken:
414462
d.oauthRefreshToken = value
463+
case OptionExternalBrowserTimeout:
464+
if value != "" {
465+
timeout, err := time.ParseDuration(value)
466+
if err != nil {
467+
return adbc.Error{
468+
Code: adbc.StatusInvalidArgument,
469+
Msg: fmt.Sprintf("invalid external browser auth timeout: %v", err),
470+
}
471+
}
472+
d.oauthU2MTimeout = timeout
473+
}
415474
default:
416475
return d.DatabaseImplBase.SetOption(key, value)
417476
}

go/adbc/driver/databricks/driver.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ package databricks
3232
import (
3333
"context"
3434
"runtime/debug"
35+
"time"
3536

3637
"github.com/apache/arrow-adbc/go/adbc"
3738
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
@@ -42,7 +43,6 @@ const (
4243
// Connection options
4344
OptionServerHostname = "databricks.server_hostname"
4445
OptionHTTPPath = "databricks.http_path"
45-
OptionAccessToken = "databricks.access_token"
4646
OptionPort = "databricks.port"
4747
OptionCatalog = "databricks.catalog"
4848
OptionSchema = "databricks.schema"
@@ -57,14 +57,25 @@ const (
5757
OptionSSLMode = "databricks.ssl_mode"
5858
OptionSSLRootCert = "databricks.ssl_root_cert"
5959

60-
// OAuth options (for future expansion)
61-
OptionOAuthClientID = "databricks.oauth.client_id"
62-
OptionOAuthClientSecret = "databricks.oauth.client_secret"
63-
OptionOAuthRefreshToken = "databricks.oauth.refresh_token"
60+
// Auth: Type
61+
OptionAuthType = "databricks.auth_type"
62+
OptionValueAuthTypeOAuthM2M = "oauth-m2m"
63+
OptionValueAuthTypeExternalBrowser = "external-browser"
64+
OptionValueAuthTypePAT = "pat"
65+
66+
// Auth: OAuth
67+
OptionOAuthClientID = "databricks.oauth.client_id"
68+
OptionOAuthClientSecret = "databricks.oauth.client_secret"
69+
OptionOAuthRefreshToken = "databricks.oauth.refresh_token"
70+
OptionExternalBrowserTimeout = "databricks.oauth.external_browser.timeout"
71+
72+
// Auth: PAT
73+
OptionAccessToken = "databricks.access_token"
6474

6575
// Default values
66-
DefaultPort = 443
67-
DefaultSSLMode = "require"
76+
DefaultPort = 443
77+
DefaultSSLMode = "require"
78+
DefaultExternalBrowserTimeout = 1 * time.Minute
6879
)
6980

7081
var (

0 commit comments

Comments
 (0)