diff --git a/acceptance/bin/browser.py b/acceptance/bin/browser.py new file mode 100755 index 0000000000..e849352326 --- /dev/null +++ b/acceptance/bin/browser.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +This script fetches a URL. +It follows redirects if applicable. + +Usage: browser.py +""" + +import urllib.request +import sys + +if len(sys.argv) < 2: + sys.stderr.write("Usage: browser.py \n") + sys.exit(1) + +url = sys.argv[1] +try: + response = urllib.request.urlopen(url) + if response.status != 200: + sys.stderr.write(f"Failed to fetch URL: {url} (status {response.status})\n") + sys.exit(1) +except Exception as e: + sys.stderr.write(f"Failed to fetch URL: {url} ({e})\n") + sys.exit(1) + +sys.exit(0) diff --git a/acceptance/cmd/auth/login/out.databrickscfg b/acceptance/cmd/auth/login/out.databrickscfg new file mode 100644 index 0000000000..99c7d54d1e --- /dev/null +++ b/acceptance/cmd/auth/login/out.databrickscfg @@ -0,0 +1,6 @@ +; The profile defined in the DEFAULT section is to be used as a fallback when no profile is explicitly specified. +[DEFAULT] + +[test] +host = [DATABRICKS_URL] +auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/out.test.toml b/acceptance/cmd/auth/login/out.test.toml new file mode 100644 index 0000000000..e092fd5ed6 --- /dev/null +++ b/acceptance/cmd/auth/login/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct-exp"] diff --git a/acceptance/cmd/auth/login/output.txt b/acceptance/cmd/auth/login/output.txt new file mode 100644 index 0000000000..b42bbd5527 --- /dev/null +++ b/acceptance/cmd/auth/login/output.txt @@ -0,0 +1,7 @@ + +>>> [CLI] auth login --host [DATABRICKS_URL] --profile test +Profile test was successfully saved + +>>> [CLI] auth profiles +Name Host Valid +test [DATABRICKS_URL] YES diff --git a/acceptance/cmd/auth/login/script b/acceptance/cmd/auth/login/script new file mode 100644 index 0000000000..814ff5876b --- /dev/null +++ b/acceptance/cmd/auth/login/script @@ -0,0 +1,11 @@ +sethome "./home" + +# Use a fake browser that performs a GET on the authorization URL +# and follows the redirect back to localhost. +export BROWSER="browser.py" + +trace $CLI auth login --host $DATABRICKS_HOST --profile test +trace $CLI auth profiles + +# Track the .databrickscfg file that was created to surface changes. +mv "./home/.databrickscfg" "./out.databrickscfg" diff --git a/acceptance/cmd/auth/login/test.toml b/acceptance/cmd/auth/login/test.toml new file mode 100644 index 0000000000..36c0e7e237 --- /dev/null +++ b/acceptance/cmd/auth/login/test.toml @@ -0,0 +1,3 @@ +Ignore = [ + "home" +] diff --git a/acceptance/script.prepare b/acceptance/script.prepare index ea348ee440..33b5e99034 100644 --- a/acceptance/script.prepare +++ b/acceptance/script.prepare @@ -97,3 +97,14 @@ envsubst() { print_telemetry_bool_values() { jq -r 'select(.path? == "/telemetry-ext") | (.body.protoLogs // [])[] | fromjson | ( (.entry // .) | (.databricks_cli_log.bundle_deploy_event.experimental.bool_values // []) ) | map("\(.key) \(.value)") | .[]' out.requests.txt | sort } + +sethome() { + local home="$1" + mkdir -p "$home" + + # For macOS and Linux, use HOME. + export HOME="$home" + + # For Windows, use USERPROFILE. + export USERPROFILE="$home" +} diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 12f4cee14b..5c10aa6620 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -14,9 +14,12 @@ import ( "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/databrickscfg/cfgpickers" "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/exec" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/u2m" + browserpkg "github.com/pkg/browser" "github.com/spf13/cobra" ) @@ -136,7 +139,11 @@ depends on the existing profiles you have set in your configuration file if err != nil { return err } - persistentAuth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(oauthArgument)) + persistentAuthOpts := []u2m.PersistentAuthOption{ + u2m.WithOAuthArgument(oauthArgument), + } + persistentAuthOpts = append(persistentAuthOpts, u2m.WithBrowser(getBrowserFunc(cmd))) + persistentAuth, err := u2m.NewPersistentAuth(ctx, persistentAuthOpts...) if err != nil { return err } @@ -288,3 +295,38 @@ func loadProfileByName(ctx context.Context, profileName string, profiler profile } return nil, nil } + +// getBrowserFunc returns a function that opens the given URL in the browser. +// It respects the BROWSER environment variable: +// - empty string: uses the default browser +// - "none": prints the URL to stdout without opening a browser +// - custom command: executes the specified command with the URL as argument +func getBrowserFunc(cmd *cobra.Command) func(url string) error { + browser := env.Get(cmd.Context(), "BROWSER") + switch browser { + case "": + return browserpkg.OpenURL + case "none": + return func(url string) error { + fmt.Fprintf(cmd.OutOrStdout(), "Please open %s in the browser to continue authentication\n", url) + return nil + } + default: + return func(url string) error { + // Run the browser command via a shell. + // It can be a script or a binary and scripts cannot be executed directly on Windows. + e, err := exec.NewCommandExecutor(".") + if err != nil { + return err + } + + e.WithInheritOutput() + cmd, err := e.StartCommand(cmd.Context(), fmt.Sprintf("%q %q", browser, url)) + if err != nil { + return err + } + + return cmd.Wait() + } + } +} diff --git a/libs/testserver/fake_oidc.go b/libs/testserver/fake_oidc.go new file mode 100644 index 0000000000..db6a682c9b --- /dev/null +++ b/libs/testserver/fake_oidc.go @@ -0,0 +1,60 @@ +package testserver + +import ( + "net/http" + "net/url" +) + +// FakeOidc holds OAuth state for acceptance tests. +type FakeOidc struct { + url string +} + +func (s *FakeOidc) OidcEndpoints() Response { + return Response{ + Body: map[string]string{ + "authorization_endpoint": s.url + "/oidc/v1/authorize", + "token_endpoint": s.url + "/oidc/v1/token", + }, + } +} + +func (s *FakeOidc) OidcAuthorize(req Request) Response { + redirectURI, err := url.Parse(req.URL.Query().Get("redirect_uri")) + if err != nil { + return Response{ + StatusCode: http.StatusBadRequest, + Body: err.Error(), + } + } + + // Compile query parameters for the redirect URL. + q := make(url.Values) + + // Include an opaque authorization code that will be used to exchange for a token. + q.Set("code", "oauth-code") + + // Include the state parameter from the original request. + q.Set("state", req.URL.Query().Get("state")) + + // Update the redirect URI with the new query parameters. + redirectURI.RawQuery = q.Encode() + + return Response{ + StatusCode: http.StatusMovedPermanently, + Headers: map[string][]string{ + "Location": {redirectURI.String()}, + }, + } +} + +func (s *FakeOidc) OidcToken(req Request) Response { + return Response{ + Body: map[string]string{ + "access_token": "oauth-token", + "expires_in": "3600", + "scope": "all-apis", + "token_type": "Bearer", + }, + } +} diff --git a/libs/testserver/handlers.go b/libs/testserver/handlers.go index 2c6343100d..2d9295e7fb 100644 --- a/libs/testserver/handlers.go +++ b/libs/testserver/handlers.go @@ -222,19 +222,15 @@ func AddDefaultHandlers(server *Server) { }) server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(_ Request) any { - return map[string]string{ - "authorization_endpoint": server.URL + "oidc/v1/authorize", - "token_endpoint": server.URL + "/oidc/v1/token", - } + return server.fakeOidc.OidcEndpoints() }) - server.Handle("POST", "/oidc/v1/token", func(_ Request) any { - return map[string]string{ - "access_token": "oauth-token", - "expires_in": "3600", - "scope": "all-apis", - "token_type": "Bearer", - } + server.Handle("GET", "/oidc/v1/authorize", func(req Request) any { + return server.fakeOidc.OidcAuthorize(req) + }) + + server.Handle("POST", "/oidc/v1/token", func(req Request) any { + return server.fakeOidc.OidcToken(req) }) server.Handle("POST", "/telemetry-ext", func(_ Request) any { diff --git a/libs/testserver/server.go b/libs/testserver/server.go index aa107b4cef..8b8e346e99 100644 --- a/libs/testserver/server.go +++ b/libs/testserver/server.go @@ -25,6 +25,7 @@ type Server struct { t testutil.TestingT fakeWorkspaces map[string]*FakeWorkspace + fakeOidc *FakeOidc mu sync.Mutex RequestCallback func(request *Request) @@ -190,6 +191,7 @@ func New(t testutil.TestingT) *Server { Router: router, t: t, fakeWorkspaces: map[string]*FakeWorkspace{}, + fakeOidc: &FakeOidc{url: server.URL}, } // Set up the not found handler as fallback