Skip to content
Merged
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Add support for discovery URL for browser based authentication flow.

### Bug Fixes

### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import java.time.Duration;
import java.util.*;
import org.apache.http.HttpMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DatabricksConfig {
private static final Logger LOG = LoggerFactory.getLogger(DatabricksConfig.class);
private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();

@ConfigAttribute(env = "DATABRICKS_HOST")
Expand Down Expand Up @@ -239,7 +242,7 @@ public TokenSource getTokenSource() {
return (TokenSource) headerFactory;
}
return new ErrorTokenSource(
String.format("OAuth Token not supported for current auth type %s", authType));
String.format("OAuth Token not supported for current auth type %s", authType));
}

public CredentialsProvider getCredentialsProvider() {
Expand Down Expand Up @@ -431,13 +434,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) {
return this;
}

/** @deprecated Use {@link #getAzureUseMsi()} instead. */
/**
* @deprecated Use {@link #getAzureUseMsi()} instead.
*/
@Deprecated()
public boolean getAzureUseMSI() {
return azureUseMsi;
}

/** @deprecated Use {@link #getAzureUseMsi()} instead. */
/**
* @deprecated Use {@link #getAzureUseMsi()} instead.
*/
@Deprecated
public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) {
this.azureUseMsi = azureUseMsi;
Expand Down Expand Up @@ -647,7 +654,19 @@ public OpenIDConnectEndpoints getOidcEndpoints() throws IOException {
if (discoveryUrl == null) {
return fetchDefaultOidcEndpoints();
}
return fetchOidcEndpointsFromDiscovery();
try {
OpenIDConnectEndpoints oidcEndpoints = fetchOidcEndpointsFromDiscovery();
if (oidcEndpoints != null) {
return oidcEndpoints;
}
} catch (Exception e) {
LOG.warn(
"Failed to fetch OIDC Endpoints using discovery URL: {}. Error: {}. \nDefaulting to fetch OIDC using default endpoint.",
discoveryUrl,
e.getMessage(),
e);
}
return fetchDefaultOidcEndpoints();
}

private OpenIDConnectEndpoints fetchOidcEndpointsFromDiscovery() {
Expand Down Expand Up @@ -676,22 +695,22 @@ private OpenIDConnectEndpoints fetchDefaultOidcEndpoints() throws IOException {
return null;
}
return new OpenIDConnectEndpoints(
realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl);
realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl);
}
if (isAccountClient() && getAccountId() != null) {
String prefix = getHost() + "/oidc/accounts/" + getAccountId();
return new OpenIDConnectEndpoints(prefix + "/v1/token", prefix + "/v1/authorize");
}

ApiClient apiClient =
new ApiClient.Builder()
.withHttpClient(getHttpClient())
.withGetHostFunc(v -> getHost())
.build();
new ApiClient.Builder()
.withHttpClient(getHttpClient())
.withGetHostFunc(v -> getHost())
.build();
try {
return apiClient.execute(
new Request("GET", "/oidc/.well-known/oauth-authorization-server"),
OpenIDConnectEndpoints.class);
new Request("GET", "/oidc/.well-known/oauth-authorization-server"),
OpenIDConnectEndpoints.class);
} catch (IOException e) {
throw new DatabricksException("IO error: " + e.getMessage(), e);
}
Expand Down Expand Up @@ -737,6 +756,7 @@ public DatabricksEnvironment getDatabricksEnvironment() {
}

private DatabricksConfig clone(Set<String> fieldsToSkip) {
fieldsToSkip.add("LOG");
DatabricksConfig newConfig = new DatabricksConfig();
for (Field f : DatabricksConfig.class.getDeclaredFields()) {
if (fieldsToSkip.contains(f.getName())) {
Expand All @@ -757,18 +777,18 @@ public DatabricksConfig clone() {

public DatabricksConfig newWithWorkspaceHost(String host) {
Set<String> fieldsToSkip =
new HashSet<>(
Arrays.asList(
// The config for WorkspaceClient has a different host and Azure Workspace resource
// ID, and also omits
// the account ID.
"host",
"accountId",
"azureWorkspaceResourceId",
// For cloud-native OAuth, we need to reauthenticate as the audience has changed, so
// don't cache the
// header factory.
"headerFactory"));
new HashSet<>(
Arrays.asList(
// The config for WorkspaceClient has a different host and Azure Workspace resource
// ID, and also omits
// the account ID.
"host",
"accountId",
"azureWorkspaceResourceId",
// For cloud-native OAuth, we need to reauthenticate as the audience has changed, so
// don't cache the
// header factory.
"headerFactory"));
return clone(fieldsToSkip).setHost(host);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ CachedTokenSource performBrowserAuth(
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withBrowserTimeout(config.getOAuthBrowserAuthTimeout())
.withScopes(new ArrayList<>(scopes))
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.build();
Consent consent = client.initiateConsent();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public static class Builder {
private HttpClient hc;
private String accountId;
private Optional<Duration> browserTimeout = Optional.empty();
private OpenIDConnectEndpoints openIDConnectEndpoints;

public Builder() {}

Expand All @@ -51,6 +52,11 @@ public Builder withHttpClient(HttpClient hc) {
return this;
}

public Builder withOpenIDConnectEndpoints(OpenIDConnectEndpoints openIDConnectEndpoints) {
this.openIDConnectEndpoints = openIDConnectEndpoints;
return this;
}

public Builder withHost(String host) {
this.host = host;
return this;
Expand Down Expand Up @@ -102,6 +108,7 @@ public Builder withBrowserTimeout(Duration browserTimeout) {
private final SecureRandom random = new SecureRandom();
private final boolean isAws;
private final boolean isAzure;
private final OpenIDConnectEndpoints openIDConnectEndpoints;
private final Optional<Duration> browserTimeout;

private OAuthClient(Builder b) throws IOException {
Expand All @@ -112,16 +119,16 @@ private OAuthClient(Builder b) throws IOException {
this.hc = b.hc;

DatabricksConfig config =
new DatabricksConfig().setHost(b.host).setAccountId(b.accountId).resolve();
OpenIDConnectEndpoints oidc = config.getOidcEndpoints();
if (oidc == null) {
new DatabricksConfig().setHost(b.host).setAccountId(b.accountId).resolve();
openIDConnectEndpoints = b.openIDConnectEndpoints;
if (openIDConnectEndpoints == null) {
throw new DatabricksException(b.host + " does not support OAuth");
}

this.isAws = config.isAws();
this.isAzure = config.isAzure();
this.tokenUrl = oidc.getTokenEndpoint();
this.authUrl = oidc.getAuthorizationEndpoint();
this.tokenUrl = openIDConnectEndpoints.getTokenEndpoint();
this.authUrl = openIDConnectEndpoints.getAuthorizationEndpoint();
this.browserTimeout = b.browserTimeout;
this.scopes = b.scopes;
}
Expand All @@ -138,6 +145,10 @@ public String getClientSecret() {
return clientSecret;
}

public OpenIDConnectEndpoints getOidcEndpoints() {
return openIDConnectEndpoints;
}

public String getRedirectUrl() {
return redirectUrl;
}
Expand Down Expand Up @@ -179,9 +190,9 @@ private static byte[] sha256(byte[] input) {

private static String urlEncode(String urlBase, Map<String, String> params) {
String queryParams =
params.entrySet().stream()
.map(entry -> entry.getKey() + "=" + entry.getValue())
.collect(Collectors.joining("&"));
params.entrySet().stream()
.map(entry -> entry.getKey() + "=" + entry.getValue())
.collect(Collectors.joining("&"));
return urlBase + "?" + queryParams.replaceAll(" ", "%20");
}

Expand All @@ -203,15 +214,15 @@ public Consent initiateConsent() throws MalformedURLException {
String url = urlEncode(authUrl, params);

return new Consent.Builder()
.withClientId(clientId)
.withClientSecret(clientSecret)
.withAuthUrl(url)
.withTokenUrl(tokenUrl)
.withRedirectUrl(redirectUrl)
.withState(state)
.withVerifier(verifier)
.withHttpClient(hc)
.withBrowserTimeout(browserTimeout)
.build();
.withClientId(clientId)
.withClientSecret(clientSecret)
.withAuthUrl(url)
.withTokenUrl(tokenUrl)
.withRedirectUrl(redirectUrl)
.withState(state)
.withVerifier(verifier)
.withHttpClient(hc)
.withBrowserTimeout(browserTimeout)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,35 @@ public void testDiscoveryEndpoint() throws IOException {
}
}

@Test
public void testDiscoveryEndpointFetchFallback() throws IOException {
String discoveryUrlSuffix = "/test.discovery.url";
String OIDCResponse =
"{\n"
+ " \"authorization_endpoint\": \"https://test.auth.endpoint/oidc/v1/authorize\",\n"
+ " \"token_endpoint\": \"https://test.auth.endpoint/oidc/v1/token\"\n"
+ "}";

try (FixtureServer server =
new FixtureServer()
.with("GET", discoveryUrlSuffix, "", 400)
.with("GET", "/oidc/.well-known/oauth-authorization-server", OIDCResponse, 200)) {

String discoveryUrl = server.getUrl() + discoveryUrlSuffix;

OpenIDConnectEndpoints oidcEndpoints =
new DatabricksConfig()
.setHost(server.getUrl())
.setDiscoveryUrl(discoveryUrl)
.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build())
.getOidcEndpoints();

assertEquals(
"https://test.auth.endpoint/oidc/v1/authorize", oidcEndpoints.getAuthorizationEndpoint());
assertEquals("https://test.auth.endpoint/oidc/v1/token", oidcEndpoints.getTokenEndpoint());
}
}

@Test
public void testNewWithWorkspaceHost() {
DatabricksConfig config =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void clientAndConsentTest() throws IOException {
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withHost(config.getHost())
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withScopes(config.getScopes())
.build();
Expand Down Expand Up @@ -94,6 +95,7 @@ void clientAndConsentTestWithCustomRedirectUrl() throws IOException {
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withHost(config.getHost())
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withScopes(config.getScopes())
.build();
Expand Down
Loading