@@ -34,6 +34,8 @@ import (
3434 "github.com/apache/arrow-adbc/go/adbc"
3535 "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
3636 dbsql "github.com/databricks/databricks-sql-go"
37+ "github.com/databricks/databricks-sql-go/auth/oauth/m2m"
38+ "github.com/databricks/databricks-sql-go/auth/oauth/u2m"
3739)
3840
3941const (
@@ -52,7 +54,6 @@ type databaseImpl struct {
5254 // Connection parameters
5355 serverHostname string
5456 httpPath string
55- accessToken string
5657 port int
5758 catalog string
5859 schema string
@@ -69,10 +70,15 @@ type databaseImpl struct {
6970 sslCertPool * x509.CertPool
7071 sslInsecure bool
7172
72- // OAuth options (for future expansion)
73+ // Auth
74+ authType string
75+
76+ accessToken string
77+
7378 oauthClientID string
7479 oauthClientSecret string
7580 oauthRefreshToken string
81+ oauthU2MTimeout time.Duration
7682}
7783
7884func (d * databaseImpl ) resolveConnectionOptions () ([]dbsql.ConnOption , error ) {
@@ -90,20 +96,60 @@ func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) {
9096 }
9197 }
9298
93- // FIXME: Support other auth methods
94- if d .accessToken == "" {
95- return nil , adbc.Error {
96- Code : adbc .StatusInvalidArgument ,
97- Msg : "access token is required" ,
98- }
99- }
100-
10199 opts := []dbsql.ConnOption {
102- dbsql .WithAccessToken (d .accessToken ),
103100 dbsql .WithServerHostname (d .serverHostname ),
104101 dbsql .WithHTTPPath (d .httpPath ),
105102 }
106103
104+ // Handle Auth configurations and validate based on user selected auth type
105+ switch d .authType {
106+ case OptionValueAuthTypePAT :
107+ if d .accessToken == "" {
108+ return nil , adbc.Error {
109+ Code : adbc .StatusInvalidArgument ,
110+ Msg : fmt .Sprintf ("access token is required when using auth type '%s'. Set this via '%s'." , OptionValueAuthTypePAT , OptionAccessToken ),
111+ }
112+ }
113+ opts = append (opts , dbsql .WithAccessToken (d .accessToken ))
114+ case OptionValueAuthTypeOAuthM2M :
115+ if d .oauthClientID == "" {
116+ return nil , adbc.Error {
117+ Code : adbc .StatusInvalidArgument ,
118+ Msg : fmt .Sprintf ("client ID is required when using auth type '%s'. Set this via '%s'." , OptionValueAuthTypeOAuthM2M , OptionOAuthClientID ),
119+ }
120+ }
121+ if d .oauthClientSecret == "" {
122+ return nil , adbc.Error {
123+ Code : adbc .StatusInvalidArgument ,
124+ Msg : fmt .Sprintf ("client secret is required when using auth type '%s'. Set this via '%s'." , OptionValueAuthTypeOAuthM2M , OptionOAuthClientSecret ),
125+ }
126+ }
127+ authenticator := m2m .NewAuthenticator (
128+ d .oauthClientID ,
129+ d .oauthClientSecret ,
130+ d .serverHostname ,
131+ )
132+ opts = append (opts , dbsql .WithAuthenticator (authenticator ))
133+ case OptionValueAuthTypeExternalBrowser :
134+ timeout := d .oauthU2MTimeout
135+ if timeout == 0 {
136+ timeout = DefaultExternalBrowserTimeout
137+ }
138+ authenticator , err := u2m .NewAuthenticator (d .serverHostname , timeout )
139+ if err != nil {
140+ return nil , adbc.Error {
141+ Code : adbc .StatusInvalidState ,
142+ Msg : fmt .Sprintf ("failed to initialize authenticator: %v" , err ),
143+ }
144+ }
145+ opts = append (opts , dbsql .WithAuthenticator (authenticator ))
146+ default :
147+ return nil , adbc.Error {
148+ Code : adbc .StatusInvalidArgument ,
149+ Msg : fmt .Sprintf ("missing required option: '%s'" , OptionAuthType ),
150+ }
151+ }
152+
107153 // Validate and set custom port
108154 // Defaults to 443
109155 if d .port != 0 {
@@ -296,6 +342,8 @@ func (d *databaseImpl) SetOption(key, value string) error {
296342 // We need to re-initialize the db/connection pool if options change
297343 d .needsRefresh = true
298344 switch key {
345+ case OptionAuthType :
346+ d .authType = value
299347 case OptionServerHostname :
300348 d .serverHostname = value
301349 case OptionHTTPPath :
@@ -412,6 +460,17 @@ func (d *databaseImpl) SetOption(key, value string) error {
412460 d .oauthClientSecret = value
413461 case OptionOAuthRefreshToken :
414462 d .oauthRefreshToken = value
463+ case OptionExternalBrowserTimeout :
464+ if value != "" {
465+ timeout , err := time .ParseDuration (value )
466+ if err != nil {
467+ return adbc.Error {
468+ Code : adbc .StatusInvalidArgument ,
469+ Msg : fmt .Sprintf ("invalid external browser auth timeout: %v" , err ),
470+ }
471+ }
472+ d .oauthU2MTimeout = timeout
473+ }
415474 default :
416475 return d .DatabaseImplBase .SetOption (key , value )
417476 }
0 commit comments