Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Add option to add a timeout for browser confirmation in the U2M authentication flow.

### Bug Fixes

* User provided scopes are now properly propagated in OAuth flows.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.sdk.core;

import java.lang.reflect.Field;
import java.time.Duration;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -41,16 +42,25 @@ public void setValueOnConfig(DatabricksConfig cfg, String value) throws IllegalA
// workspace clients or config resolution) are safe
synchronized (field) {
field.setAccessible(true);
if (value == null || value.trim().isEmpty()) {
return;
}

if (field.getType() == String.class) {
field.set(cfg, value);
} else if (field.getType() == int.class) {
field.set(cfg, Integer.parseInt(value));
} else if (field.getType() == Integer.class) {
field.set(cfg, Integer.parseInt(value));
} else if (field.getType() == boolean.class) {
field.set(cfg, Boolean.parseBoolean(value));
} else if (field.getType() == Boolean.class) {
field.set(cfg, Boolean.parseBoolean(value));
} else if (field.getType() == Duration.class) {
int seconds = Integer.parseInt(value);
field.set(cfg, seconds > 0 ? Duration.ofSeconds(seconds) : null);
} else if (field.getType() == ProxyConfig.ProxyAuthType.class) {
if (value != null) {
field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value));
}
field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value));
}
field.setAccessible(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.*;
import org.apache.http.HttpMessage;

Expand Down Expand Up @@ -163,6 +164,13 @@ public class DatabricksConfig {
@ConfigAttribute(env = "DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH")
private Boolean disableAsyncTokenRefresh;

/**
* The duration to wait for a browser response during U2M authentication before timing out. If set
* to 0 or null, the connector waits for an indefinite amount of time.
*/
@ConfigAttribute(env = "DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT")
private Duration oauthBrowserAuthTimeout;

public Environment getEnv() {
return env;
}
Expand Down Expand Up @@ -597,6 +605,15 @@ public DatabricksConfig setDisableAsyncTokenRefresh(boolean disableAsyncTokenRef
return this;
}

public Duration getOAuthBrowserAuthTimeout() {
return oauthBrowserAuthTimeout;
}

public DatabricksConfig setOAuthBrowserAuthTimeout(Duration oauthBrowserAuthTimeout) {
this.oauthBrowserAuthTimeout = oauthBrowserAuthTimeout;
return this;
}

public boolean isAzure() {
if (azureWorkspaceResourceId != null) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import java.io.Serializable;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -53,6 +55,7 @@ public class Consent implements Serializable {
private final String redirectUrl;
private final String clientId;
private final String clientSecret;
private final Optional<Duration> browserTimeout;

public static class Builder {
private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
Expand All @@ -63,6 +66,7 @@ public static class Builder {
private String redirectUrl;
private String clientId;
private String clientSecret;
private Optional<Duration> browserTimeout = Optional.empty();

public Builder withHttpClient(HttpClient hc) {
this.hc = hc;
Expand Down Expand Up @@ -104,6 +108,11 @@ public Builder withClientSecret(String clientSecret) {
return this;
}

public Builder withBrowserTimeout(Optional<Duration> browserTimeout) {
this.browserTimeout = browserTimeout;
return this;
}

public Consent build() {
return new Consent(this);
}
Expand All @@ -119,6 +128,7 @@ private Consent(Builder builder) {
this.clientId = Objects.requireNonNull(builder.clientId);
// This may be null for native apps or single-page apps.
this.clientSecret = builder.clientSecret;
this.browserTimeout = builder.browserTimeout;
}

public Consent setHttpClient(HttpClient hc) {
Expand Down Expand Up @@ -155,6 +165,10 @@ public String getClientSecret() {
return clientSecret;
}

public Optional<Duration> getBrowserTimeout() {
return browserTimeout;
}

/**
* Launch a browser to collect an authorization code and exchange the code for an OAuth token.
*
Expand Down Expand Up @@ -219,7 +233,8 @@ private Map<String, String> getOAuthCallbackParameters() throws IOException {
+ redirect.getHost()
+ ", redirectUrl host must be one of: localhost, 127.0.0.1");
}
CallbackResponseHandler handler = new CallbackResponseHandler();

CallbackResponseHandler handler = new CallbackResponseHandler(this.browserTimeout);
HttpServer httpServer =
HttpServer.create(new InetSocketAddress(redirect.getHost(), redirect.getPort()), 0);
httpServer.createContext("/", handler);
Expand Down Expand Up @@ -282,9 +297,13 @@ private Token exchange(String code, String state) {

static class CallbackResponseHandler implements HttpHandler {
private final Logger LOG = LoggerFactory.getLogger(getClass().getName());
// Protects params
private final Object lock = new Object();
protected final Object lock = new Object(); // protected for testing
private volatile Map<String, String> params;
private final Optional<Duration> timeout;

public CallbackResponseHandler(Optional<Duration> timeout) {
this.timeout = timeout;
}

@Override
public void handle(HttpExchange exchange) {
Expand Down Expand Up @@ -323,10 +342,7 @@ public void handleInner(HttpExchange exchange) throws IOException {
});

sendSuccess(exchange);
synchronized (lock) {
params = theseParams;
lock.notify();
}
setParams(theseParams);
}

private void sendError(
Expand Down Expand Up @@ -373,7 +389,16 @@ public Map<String, String> getParams() {
synchronized (lock) {
if (params == null) {
try {
lock.wait();
if (timeout.isPresent()) {
Duration t = timeout.get();
lock.wait(t.toMillis());
if (params == null) {
throw new DatabricksException(
"OAuth browser authentication timed out after " + t.getSeconds() + " seconds");
}
} else {
lock.wait();
}
} catch (InterruptedException e) {
throw new DatabricksException(
"Interrupted while waiting for parameters: " + e.getMessage(), e);
Expand All @@ -382,5 +407,12 @@ public Map<String, String> getParams() {
return params;
}
}

void setParams(Map<String, String> params) {
synchronized (lock) {
this.params = params;
lock.notify();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ CachedTokenSource performBrowserAuth(
.withHost(config.getHost())
.withAccountId(config.getAccountId())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withBrowserTimeout(config.getOAuthBrowserAuthTimeout())
.withScopes(new ArrayList<>(scopes))
.build();
Consent consent = client.initiateConsent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.*;
import java.time.Duration;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
Expand All @@ -36,6 +42,7 @@ public static class Builder {
private String clientSecret;
private HttpClient hc;
private String accountId;
private Optional<Duration> browserTimeout = Optional.empty();

public Builder() {}

Expand Down Expand Up @@ -77,6 +84,11 @@ public Builder withAccountId(String accountId) {
this.accountId = accountId;
return this;
}

public Builder withBrowserTimeout(Duration browserTimeout) {
this.browserTimeout = Optional.of(browserTimeout);
return this;
}
}

private final String clientId;
Expand All @@ -90,6 +102,7 @@ public Builder withAccountId(String accountId) {
private final SecureRandom random = new SecureRandom();
private final boolean isAws;
private final boolean isAzure;
private final Optional<Duration> browserTimeout;

private OAuthClient(Builder b) throws IOException {
this.clientId = Objects.requireNonNull(b.clientId);
Expand All @@ -109,6 +122,7 @@ private OAuthClient(Builder b) throws IOException {
this.isAzure = config.isAzure();
this.tokenUrl = oidc.getTokenEndpoint();
this.authUrl = oidc.getAuthorizationEndpoint();
this.browserTimeout = b.browserTimeout;
this.scopes = b.scopes;
}

Expand Down Expand Up @@ -197,6 +211,7 @@ public Consent initiateConsent() throws MalformedURLException {
.withState(state)
.withVerifier(verifier)
.withHttpClient(hc)
.withBrowserTimeout(browserTimeout)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.databricks.sdk.core.oauth.TokenSource;
import com.databricks.sdk.core.utils.Environment;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -250,4 +251,35 @@ public void testGetTokenSourceWithOAuth() {
assertFalse(tokenSource instanceof ErrorTokenSource);
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
}

@Test
public void testOAuthBrowserAuthTimeout() {
DatabricksConfig config = new DatabricksConfig();

assertNull(config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(30));
assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(60));
assertEquals(Duration.ofSeconds(60), config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(0));
assertEquals(Duration.ZERO, config.getOAuthBrowserAuthTimeout());
}

@Test
public void testEnvironmentVariableLoading() {
Map<String, String> env = new HashMap<>();
env.put("DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT", "30");
env.put("DATABRICKS_DEBUG_TRUNCATE_BYTES", "100");
env.put("DATABRICKS_RATE_LIMIT", "50");

DatabricksConfig config = new DatabricksConfig();
config.resolve(new Environment(env, new ArrayList<>(), System.getProperty("os.name")));

assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout());
assertEquals(Integer.valueOf(100), config.getDebugTruncateBytes());
assertEquals(Integer.valueOf(50), config.getRateLimit());
}
}
Loading
Loading