diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index b60768575..20502a976 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Add option to add a timeout for browser confirmation in the U2M authentication flow. + ### Bug Fixes * User provided scopes are now properly propagated in OAuth flows. diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ConfigAttributeAccessor.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ConfigAttributeAccessor.java index 73cb3cba2..06502b508 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ConfigAttributeAccessor.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ConfigAttributeAccessor.java @@ -1,6 +1,7 @@ package com.databricks.sdk.core; import java.lang.reflect.Field; +import java.time.Duration; import java.util.Map; import java.util.Objects; @@ -41,16 +42,25 @@ public void setValueOnConfig(DatabricksConfig cfg, String value) throws IllegalA // workspace clients or config resolution) are safe synchronized (field) { field.setAccessible(true); + if (value == null || value.trim().isEmpty()) { + return; + } + if (field.getType() == String.class) { field.set(cfg, value); } else if (field.getType() == int.class) { field.set(cfg, Integer.parseInt(value)); + } else if (field.getType() == Integer.class) { + field.set(cfg, Integer.parseInt(value)); } else if (field.getType() == boolean.class) { field.set(cfg, Boolean.parseBoolean(value)); + } else if (field.getType() == Boolean.class) { + field.set(cfg, Boolean.parseBoolean(value)); + } else if (field.getType() == Duration.class) { + int seconds = Integer.parseInt(value); + field.set(cfg, seconds > 0 ? Duration.ofSeconds(seconds) : null); } else if (field.getType() == ProxyConfig.ProxyAuthType.class) { - if (value != null) { - field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value)); - } + field.set(cfg, ProxyConfig.ProxyAuthType.valueOf(value)); } field.setAccessible(false); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 9d6af5ada..074e97974 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -14,6 +14,7 @@ import java.io.File; import java.io.IOException; import java.lang.reflect.Field; +import java.time.Duration; import java.util.*; import org.apache.http.HttpMessage; @@ -163,6 +164,13 @@ public class DatabricksConfig { @ConfigAttribute(env = "DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") private Boolean disableAsyncTokenRefresh; + /** + * The duration to wait for a browser response during U2M authentication before timing out. If set + * to 0 or null, the connector waits for an indefinite amount of time. + */ + @ConfigAttribute(env = "DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT") + private Duration oauthBrowserAuthTimeout; + public Environment getEnv() { return env; } @@ -597,6 +605,15 @@ public DatabricksConfig setDisableAsyncTokenRefresh(boolean disableAsyncTokenRef return this; } + public Duration getOAuthBrowserAuthTimeout() { + return oauthBrowserAuthTimeout; + } + + public DatabricksConfig setOAuthBrowserAuthTimeout(Duration oauthBrowserAuthTimeout) { + this.oauthBrowserAuthTimeout = oauthBrowserAuthTimeout; + return this; + } + public boolean isAzure() { if (azureWorkspaceResourceId != null) { return true; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java index aee9fe50f..19619d127 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java @@ -14,10 +14,12 @@ import java.io.Serializable; import java.net.*; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,6 +55,7 @@ public class Consent implements Serializable { private final String redirectUrl; private final String clientId; private final String clientSecret; + private final Optional browserTimeout; public static class Builder { private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build(); @@ -63,6 +66,7 @@ public static class Builder { private String redirectUrl; private String clientId; private String clientSecret; + private Optional browserTimeout = Optional.empty(); public Builder withHttpClient(HttpClient hc) { this.hc = hc; @@ -104,6 +108,11 @@ public Builder withClientSecret(String clientSecret) { return this; } + public Builder withBrowserTimeout(Optional browserTimeout) { + this.browserTimeout = browserTimeout; + return this; + } + public Consent build() { return new Consent(this); } @@ -119,6 +128,7 @@ private Consent(Builder builder) { this.clientId = Objects.requireNonNull(builder.clientId); // This may be null for native apps or single-page apps. this.clientSecret = builder.clientSecret; + this.browserTimeout = builder.browserTimeout; } public Consent setHttpClient(HttpClient hc) { @@ -155,6 +165,10 @@ public String getClientSecret() { return clientSecret; } + public Optional getBrowserTimeout() { + return browserTimeout; + } + /** * Launch a browser to collect an authorization code and exchange the code for an OAuth token. * @@ -219,15 +233,19 @@ private Map getOAuthCallbackParameters() throws IOException { + redirect.getHost() + ", redirectUrl host must be one of: localhost, 127.0.0.1"); } - CallbackResponseHandler handler = new CallbackResponseHandler(); + + CallbackResponseHandler handler = new CallbackResponseHandler(this.browserTimeout); HttpServer httpServer = HttpServer.create(new InetSocketAddress(redirect.getHost(), redirect.getPort()), 0); httpServer.createContext("/", handler); httpServer.start(); - desktopBrowser(); - Map params = handler.getParams(); - httpServer.stop(0); - return params; + + try { + desktopBrowser(); + return handler.getParams(); + } finally { + httpServer.stop(0); + } } /** @@ -282,9 +300,13 @@ private Token exchange(String code, String state) { static class CallbackResponseHandler implements HttpHandler { private final Logger LOG = LoggerFactory.getLogger(getClass().getName()); - // Protects params - private final Object lock = new Object(); + protected final Object lock = new Object(); // protected for testing private volatile Map params; + private final Optional timeout; + + public CallbackResponseHandler(Optional timeout) { + this.timeout = timeout; + } @Override public void handle(HttpExchange exchange) { @@ -323,10 +345,7 @@ public void handleInner(HttpExchange exchange) throws IOException { }); sendSuccess(exchange); - synchronized (lock) { - params = theseParams; - lock.notify(); - } + setParams(theseParams); } private void sendError( @@ -369,11 +388,25 @@ private void sendSuccess(HttpExchange exchange) throws IOException { exchange.close(); } + /** + * Wait and return the params. + * + *

This method might throw an exception in case of timeout. + */ public Map getParams() { synchronized (lock) { if (params == null) { try { - lock.wait(); + if (timeout.isPresent()) { + Duration t = timeout.get(); + lock.wait(t.toMillis()); + if (params == null) { + throw new DatabricksException( + "OAuth browser authentication timed out after " + t.getSeconds() + " seconds"); + } + } else { + lock.wait(); + } } catch (InterruptedException e) { throw new DatabricksException( "Interrupted while waiting for parameters: " + e.getMessage(), e); @@ -382,5 +415,12 @@ public Map getParams() { return params; } } + + void setParams(Map params) { + synchronized (lock) { + this.params = params; + lock.notify(); + } + } } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index 019dbf6bc..95780b6fa 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -128,6 +128,7 @@ CachedTokenSource performBrowserAuth( .withHost(config.getHost()) .withAccountId(config.getAccountId()) .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withBrowserTimeout(config.getOAuthBrowserAuthTimeout()) .withScopes(new ArrayList<>(scopes)) .build(); Consent consent = client.initiateConsent(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java index 25df0c0a2..3803bafb0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java @@ -9,7 +9,13 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; -import java.util.*; +import java.time.Duration; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -36,6 +42,7 @@ public static class Builder { private String clientSecret; private HttpClient hc; private String accountId; + private Optional browserTimeout = Optional.empty(); public Builder() {} @@ -77,6 +84,11 @@ public Builder withAccountId(String accountId) { this.accountId = accountId; return this; } + + public Builder withBrowserTimeout(Duration browserTimeout) { + this.browserTimeout = Optional.of(browserTimeout); + return this; + } } private final String clientId; @@ -90,6 +102,7 @@ public Builder withAccountId(String accountId) { private final SecureRandom random = new SecureRandom(); private final boolean isAws; private final boolean isAzure; + private final Optional browserTimeout; private OAuthClient(Builder b) throws IOException { this.clientId = Objects.requireNonNull(b.clientId); @@ -109,6 +122,7 @@ private OAuthClient(Builder b) throws IOException { this.isAzure = config.isAzure(); this.tokenUrl = oidc.getTokenEndpoint(); this.authUrl = oidc.getAuthorizationEndpoint(); + this.browserTimeout = b.browserTimeout; this.scopes = b.scopes; } @@ -197,6 +211,7 @@ public Consent initiateConsent() throws MalformedURLException { .withState(state) .withVerifier(verifier) .withHttpClient(hc) + .withBrowserTimeout(browserTimeout) .build(); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java index b3ac333a3..88a466a32 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java @@ -12,6 +12,7 @@ import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -250,4 +251,35 @@ public void testGetTokenSourceWithOAuth() { assertFalse(tokenSource instanceof ErrorTokenSource); assertEquals(tokenSource.getToken().getAccessToken(), "test-token"); } + + @Test + public void testOAuthBrowserAuthTimeout() { + DatabricksConfig config = new DatabricksConfig(); + + assertNull(config.getOAuthBrowserAuthTimeout()); + + config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(30)); + assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout()); + + config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(60)); + assertEquals(Duration.ofSeconds(60), config.getOAuthBrowserAuthTimeout()); + + config.setOAuthBrowserAuthTimeout(Duration.ofSeconds(0)); + assertEquals(Duration.ZERO, config.getOAuthBrowserAuthTimeout()); + } + + @Test + public void testEnvironmentVariableLoading() { + Map env = new HashMap<>(); + env.put("DATABRICKS_OAUTH_BROWSER_AUTH_TIMEOUT", "30"); + env.put("DATABRICKS_DEBUG_TRUNCATE_BYTES", "100"); + env.put("DATABRICKS_RATE_LIMIT", "50"); + + DatabricksConfig config = new DatabricksConfig(); + config.resolve(new Environment(env, new ArrayList<>(), System.getProperty("os.name"))); + + assertEquals(Duration.ofSeconds(30), config.getOAuthBrowserAuthTimeout()); + assertEquals(Integer.valueOf(100), config.getDebugTruncateBytes()); + assertEquals(Integer.valueOf(50), config.getRateLimit()); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ConsentTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ConsentTest.java new file mode 100644 index 000000000..8ea7a8d3d --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ConsentTest.java @@ -0,0 +1,132 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.sdk.core.DatabricksException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +public class ConsentTest { + + @Test + public void testConsentWithBrowserAuthTimeout() { + Consent consent = + new Consent.Builder() + .withClientId("test-client-id") + .withClientSecret("test-client-secret") + .withAuthUrl("https://test.com/auth") + .withTokenUrl("https://test.com/token") + .withRedirectUrl("http://localhost:8080/callback") + .withState("test-state") + .withVerifier("test-verifier") + .withBrowserTimeout(Optional.of(Duration.ofSeconds(30))) + .build(); + + assertEquals(Optional.of(Duration.ofSeconds(30)), consent.getBrowserTimeout()); + } + + @Test + public void testConsentWithoutBrowserAuthTimeout() { + Consent consent = + new Consent.Builder() + .withClientId("test-client-id") + .withClientSecret("test-client-secret") + .withAuthUrl("https://test.com/auth") + .withTokenUrl("https://test.com/token") + .withRedirectUrl("http://localhost:8080/callback") + .withState("test-state") + .withVerifier("test-verifier") + .build(); + + assertEquals(Optional.empty(), consent.getBrowserTimeout()); + } + + @Test + public void testTimeoutLogicWithShortTimeout() throws InterruptedException { + // Test that timeout is enforced correctly. + Consent.CallbackResponseHandler handler = + new Consent.CallbackResponseHandler(Optional.of(Duration.ofMillis(100))); // 100ms timeout + + long startTime = System.currentTimeMillis(); + + try { + handler.getParams(); + fail("Expected timeout exception"); + } catch (DatabricksException e) { + long elapsedTime = System.currentTimeMillis() - startTime; + assertTrue( + elapsedTime >= 100, "Timeout should have taken at least 100ms, but took " + elapsedTime); + assertTrue(e.getMessage().contains("timed out after 0 seconds")); + } + } + + @Test + public void testTimeoutLogicWithNoTimeout() throws InterruptedException { + // Test that no timeout means indefinite wait. + Consent.CallbackResponseHandler handler = new Consent.CallbackResponseHandler(Optional.empty()); + + CountDownLatch latch = new CountDownLatch(1); + + Thread setterThread = + new Thread( + () -> { + try { + Thread.sleep(50); + synchronized (handler.lock) { + Map params = new HashMap<>(); + params.put("code", "test-code"); + params.put("state", "test-state"); + handler.setParams(params); + } + latch.countDown(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + setterThread.start(); + + Map result = handler.getParams(); + assertNotNull(result); + assertEquals("test-code", result.get("code")); + assertEquals("test-state", result.get("state")); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + } + + @Test + public void testTimeoutLogicWithParamsSetBeforeTimeout() throws InterruptedException { + // Test that if params are set before timeout, no exception is thrown. + Consent.CallbackResponseHandler handler = + new Consent.CallbackResponseHandler(Optional.of(Duration.ofSeconds(1))); + + CountDownLatch latch = new CountDownLatch(1); + + Thread setterThread = + new Thread( + () -> { + try { + Thread.sleep(50); + synchronized (handler.lock) { + Map params = new HashMap<>(); + params.put("code", "test-code"); + handler.setParams(params); + } + latch.countDown(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + setterThread.start(); + + Map result = handler.getParams(); + assertNotNull(result); + assertEquals("test-code", result.get("code")); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + } +}