Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Add option to add a timeout for browser confirmation in the U2M authentication flow.

### Bug Fixes

* User provided scopes are now properly propagated in OAuth flows.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.sdk.core;

import java.lang.reflect.Field;
import java.time.Duration;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -41,16 +42,25 @@ public void setValueOnConfig(DatabricksConfig cfg, String value) throws IllegalA
// workspace clients or config resolution) are safe
synchronized (field) {
field.setAccessible(true);
if (value == null || value.trim().isEmpty()) {
return;
}

if (field.getType() == String.class) {
field.set(cfg, value);
} else if (field.getType() == int.class) {
field.set(cfg, Integer.parseInt(value));
} else if (field.getType() == Integer.class) {
field.set(cfg, Integer.parseInt(value));
} else if (field.getType() == boolean.class) {
field.set(cfg, Boolean.parseBoolean(value));
} else if (field.getType() == Boolean.class) {
field.set(cfg, Boolean.parseBoolean(value));
} else if (field.getType() == Duration.class) {
int seconds = Integer.parseInt(value);
field.set(cfg, seconds > 0 ? Duration.ofSeconds(seconds) : null);
} else if (field.getType() == ProxyConfig.ProxyAuthType.class) {
if (value != null) {
field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value));
}
field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value));
}
field.setAccessible(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.*;
import org.apache.http.HttpMessage;

Expand Down Expand Up @@ -163,6 +164,13 @@ public class DatabricksConfig {
@ConfigAttribute(env = "DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH")
private Boolean disableAsyncTokenRefresh;

/**
* The duration to wait for a browser response during U2M authentication before timing out. If set
* to 0 or null, the connector waits for an indefinite amount of time.
*/
@ConfigAttribute(env = "DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT")
private Duration oauthBrowserAuthTimeout;

public Environment getEnv() {
return env;
}
Expand Down Expand Up @@ -597,6 +605,15 @@ public DatabricksConfig setDisableAsyncTokenRefresh(boolean disableAsyncTokenRef
return this;
}

public Duration getOAuthBrowserAuthTimeout() {
return oauthBrowserAuthTimeout;
}

public DatabricksConfig setOAuthBrowserAuthTimeout(Duration oauthBrowserAuthTimeout) {
this.oauthBrowserAuthTimeout = oauthBrowserAuthTimeout;
return this;
}

public boolean isAzure() {
if (azureWorkspaceResourceId != null) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import java.io.Serializable;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -53,6 +55,7 @@ public class Consent implements Serializable {
private final String redirectUrl;
private final String clientId;
private final String clientSecret;
private final Optional<Duration> browserTimeout;

public static class Builder {
private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
Expand All @@ -63,6 +66,7 @@ public static class Builder {
private String redirectUrl;
private String clientId;
private String clientSecret;
private Optional<Duration> browserTimeout = Optional.empty();

public Builder withHttpClient(HttpClient hc) {
this.hc = hc;
Expand Down Expand Up @@ -104,6 +108,11 @@ public Builder withClientSecret(String clientSecret) {
return this;
}

public Builder withBrowserTimeout(Optional<Duration> browserTimeout) {
this.browserTimeout = browserTimeout;
return this;
}

public Consent build() {
return new Consent(this);
}
Expand All @@ -119,6 +128,7 @@ private Consent(Builder builder) {
this.clientId = Objects.requireNonNull(builder.clientId);
// This may be null for native apps or single-page apps.
this.clientSecret = builder.clientSecret;
this.browserTimeout = builder.browserTimeout;
}

public Consent setHttpClient(HttpClient hc) {
Expand Down Expand Up @@ -155,6 +165,10 @@ public String getClientSecret() {
return clientSecret;
}

public Optional<Duration> getBrowserTimeout() {
return browserTimeout;
}

/**
* Launch a browser to collect an authorization code and exchange the code for an OAuth token.
*
Expand Down Expand Up @@ -219,15 +233,19 @@ private Map<String, String> getOAuthCallbackParameters() throws IOException {
+ redirect.getHost()
+ ", redirectUrl host must be one of: localhost, 127.0.0.1");
}
CallbackResponseHandler handler = new CallbackResponseHandler();

CallbackResponseHandler handler = new CallbackResponseHandler(this.browserTimeout);
HttpServer httpServer =
HttpServer.create(new InetSocketAddress(redirect.getHost(), redirect.getPort()), 0);
httpServer.createContext("/", handler);
httpServer.start();
desktopBrowser();
Map<String, String> params = handler.getParams();
httpServer.stop(0);
return params;

try {
desktopBrowser();
return handler.getParams();
} finally {
httpServer.stop(0);
}
}

/**
Expand Down Expand Up @@ -282,9 +300,13 @@ private Token exchange(String code, String state) {

static class CallbackResponseHandler implements HttpHandler {
private final Logger LOG = LoggerFactory.getLogger(getClass().getName());
// Protects params
private final Object lock = new Object();
protected final Object lock = new Object(); // protected for testing
private volatile Map<String, String> params;
private final Optional<Duration> timeout;

public CallbackResponseHandler(Optional<Duration> timeout) {
this.timeout = timeout;
}

@Override
public void handle(HttpExchange exchange) {
Expand Down Expand Up @@ -323,10 +345,7 @@ public void handleInner(HttpExchange exchange) throws IOException {
});

sendSuccess(exchange);
synchronized (lock) {
params = theseParams;
lock.notify();
}
setParams(theseParams);
}

private void sendError(
Expand Down Expand Up @@ -369,11 +388,25 @@ private void sendSuccess(HttpExchange exchange) throws IOException {
exchange.close();
}

/**
* Wait and return the params.
*
* <p>This method might throw an exception in case of timeout.
*/
public Map<String, String> getParams() {
synchronized (lock) {
if (params == null) {
try {
lock.wait();
if (timeout.isPresent()) {
Duration t = timeout.get();
lock.wait(t.toMillis());
if (params == null) {
throw new DatabricksException(
"OAuth browser authentication timed out after " + t.getSeconds() + " seconds");
}
} else {
lock.wait();
}
} catch (InterruptedException e) {
throw new DatabricksException(
"Interrupted while waiting for parameters: " + e.getMessage(), e);
Expand All @@ -382,5 +415,12 @@ public Map<String, String> getParams() {
return params;
}
}

void setParams(Map<String, String> params) {
synchronized (lock) {
this.params = params;
lock.notify();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ CachedTokenSource performBrowserAuth(
.withHost(config.getHost())
.withAccountId(config.getAccountId())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withBrowserTimeout(config.getOAuthBrowserAuthTimeout())
.withScopes(new ArrayList<>(scopes))
.build();
Consent consent = client.initiateConsent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.*;
import java.time.Duration;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
Expand All @@ -36,6 +42,7 @@ public static class Builder {
private String clientSecret;
private HttpClient hc;
private String accountId;
private Optional<Duration> browserTimeout = Optional.empty();

public Builder() {}

Expand Down Expand Up @@ -77,6 +84,11 @@ public Builder withAccountId(String accountId) {
this.accountId = accountId;
return this;
}

public Builder withBrowserTimeout(Duration browserTimeout) {
this.browserTimeout = Optional.of(browserTimeout);
return this;
}
}

private final String clientId;
Expand All @@ -90,6 +102,7 @@ public Builder withAccountId(String accountId) {
private final SecureRandom random = new SecureRandom();
private final boolean isAws;
private final boolean isAzure;
private final Optional<Duration> browserTimeout;

private OAuthClient(Builder b) throws IOException {
this.clientId = Objects.requireNonNull(b.clientId);
Expand All @@ -109,6 +122,7 @@ private OAuthClient(Builder b) throws IOException {
this.isAzure = config.isAzure();
this.tokenUrl = oidc.getTokenEndpoint();
this.authUrl = oidc.getAuthorizationEndpoint();
this.browserTimeout = b.browserTimeout;
this.scopes = b.scopes;
}

Expand Down Expand Up @@ -197,6 +211,7 @@ public Consent initiateConsent() throws MalformedURLException {
.withState(state)
.withVerifier(verifier)
.withHttpClient(hc)
.withBrowserTimeout(browserTimeout)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.databricks.sdk.core.oauth.TokenSource;
import com.databricks.sdk.core.utils.Environment;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -250,4 +251,35 @@ public void testGetTokenSourceWithOAuth() {
assertFalse(tokenSource instanceof ErrorTokenSource);
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
}

@Test
public void testOAuthBrowserAuthTimeout() {
DatabricksConfig config = new DatabricksConfig();

assertNull(config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(30));
assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(60));
assertEquals(Duration.ofSeconds(60), config.getOAuthBrowserAuthTimeout());

config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(0));
assertEquals(Duration.ZERO, config.getOAuthBrowserAuthTimeout());
}

@Test
public void testEnvironmentVariableLoading() {
Map<String, String> env = new HashMap<>();
env.put("DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT", "30");
env.put("DATABRICKS_DEBUG_TRUNCATE_BYTES", "100");
env.put("DATABRICKS_RATE_LIMIT", "50");

DatabricksConfig config = new DatabricksConfig();
config.resolve(new Environment(env, new ArrayList<>(), System.getProperty("os.name")));

assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout());
assertEquals(Integer.valueOf(100), config.getDebugTruncateBytes());
assertEquals(Integer.valueOf(50), config.getRateLimit());
}
}
Loading
Loading