Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.57.0

### New Features and Improvements
- Azure Service Principal credential provider can now automatically discover the tenant ID when it is not explicitly provided

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.*;
import org.apache.http.HttpMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DatabricksConfig {

private static final Logger logger = LoggerFactory.getLogger(DatabricksConfig.class);
private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();

@ConfigAttribute(env = "DATABRICKS_HOST")
Expand Down Expand Up @@ -726,7 +731,7 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
}

public DatabricksConfig clone() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to clone this method for all our config classes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate on this comment? Are you asking whether we need to implement similar clone() methods in other configuration classes, or are you suggesting a different architectural approach for handling object cloning?

return clone(new HashSet<>());
return clone(new HashSet<>(Collections.singletonList("logger")));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Contributor Author

@sreekanth-db sreekanth-db Aug 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logger field is excluded from cloning because it's a static final field. The clone(Set<String> fieldsToSkip) method uses reflection to copy field values with f.set(newConfig, f.get(this)).

Attempting to clone a static field causes an IllegalAccessException with the message: "Can not set static final org.slf4j.Logger field com.databricks.sdk.core.DatabricksConfig.logger to org.slf4j.reload4j.Reload4jLoggerAdapter"

}

public DatabricksConfig newWithWorkspaceHost(String host) {
Expand All @@ -736,6 +741,7 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
// The config for WorkspaceClient has a different host and Azure Workspace resource
// ID, and also omits
// the account ID.
"logger",
"host",
"accountId",
"azureWorkspaceResourceId",
Expand All @@ -755,4 +761,81 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
/**
 * Returns the OAuth redirect URL to use.
 *
 * @return the configured {@code redirectUrl} when it is non-null, otherwise the default {@code
 *     http://localhost:8080/callback}
 */
public String getEffectiveOAuthRedirectUrl() {
  if (redirectUrl == null) {
    // Nothing configured: fall back to the local callback endpoint.
    return "http://localhost:8080/callback";
  }
  return redirectUrl;
}

/**
 * [Internal] Load the Azure tenant ID from the Azure Databricks login page. If the tenant ID is
 * already set, this method does nothing.
 *
 * <p>The lookup is best-effort. It is skipped when the config is not an Azure config or when no
 * host is set, and any failure is logged as a warning rather than propagated, leaving {@code
 * azureTenantId} unchanged.
 */
public void loadAzureTenantId() {
  if (!isAzure() || azureTenantId != null || host == null) {
    return;
  }

  final String azureAuthEndpoint = "/aad/auth";
  String loginUrl = host + azureAuthEndpoint;
  logger.debug("Loading tenant ID from {}", loginUrl);

  try {
    // The login endpoint is expected to answer with a redirect whose target embeds the tenant ID.
    String redirectLocation = getRedirectLocation(loginUrl);
    if (redirectLocation == null) {
      return;
    }

    String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
    if (extractedTenantId == null) {
      return;
    }

    this.azureTenantId = extractedTenantId;
    logger.debug("Loaded tenant ID: {}", this.azureTenantId);
  } catch (Exception e) {
    // Pass the exception itself as the trailing argument so SLF4J records its type and stack
    // trace; logging only getMessage() loses the cause (and prints "null" for message-less
    // exceptions).
    logger.warn("Failed to load tenant ID: {}", e.getMessage(), e);
  }
}

/**
 * Issues a GET to {@code loginUrl} and returns the value of the {@code Location} header of the
 * response.
 *
 * @param loginUrl URL to request
 * @return the {@code Location} header value, or {@code null} when the response status is not 3xx
 *     or the header is absent (a warning is logged in both cases)
 * @throws IOException if executing the request fails
 */
private String getRedirectLocation(String loginUrl) throws IOException {
  Request probe = new Request("GET", loginUrl);
  // NOTE(review): presumably stops the client from following the redirect so the 3xx response
  // itself can be inspected — confirm against Request#setRedirectionBehavior.
  probe.setRedirectionBehavior(false);
  Response reply = getHttpClient().execute(probe);

  int status = reply.getStatusCode();
  boolean isRedirect = status >= 300 && status < 400;
  if (!isRedirect) {
    logger.warn(
        "Failed to get tenant ID from {}: expected status code 3xx, got {}", loginUrl, status);
    return null;
  }

  String target = reply.getFirstHeader("Location");
  if (target != null) {
    return target;
  }

  logger.warn("No Location header in response from {}", loginUrl);
  return null;
}

/**
 * Extracts the tenant ID from a redirect URL.
 *
 * <p>The Location header has the following form: {@code
 * https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...}. The domain may change
 * depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
 *
 * @param redirectUrl value of the {@code Location} header
 * @return the first path segment of the URL, or {@code null} when the URL cannot be parsed or
 *     has no non-empty first path segment (a warning is logged)
 */
private String extractTenantIdFromUrl(String redirectUrl) {
  try {
    URL entraIdUrl = new URL(redirectUrl);
    // The path starts with "/", so element 0 of the split is empty and the tenant ID, when
    // present, is element 1.
    String[] pathSegments = entraIdUrl.getPath().split("/");

    // Reject an empty segment as well (e.g. a path like "//oauth2/authorize"); otherwise an
    // empty string would be reported as the tenant ID.
    if (pathSegments.length < 2 || pathSegments[1].isEmpty()) {
      logger.warn("Invalid path in Location header: {}", entraIdUrl.getPath());
      return null;
    }

    return pathSegments[1];
  } catch (Exception e) {
    logger.warn("Failed to extract tenant ID from URL {}: {}", redirectUrl, e.getMessage());
    return null;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ public String authType() {
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
|| config.getAzureClientSecret() == null
|| config.getAzureTenantId() == null) {
|| config.getAzureClientSecret() == null) {
return null;
}
AzureUtils.ensureHostPresent(
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);
config.loadAzureTenantId();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you return null if loading the Azure tenant ID fails, to be consistent with line 26?
Also, can we keep the Azure service principal-specific helpers for extracting the tenant ID here in the credentials provider itself?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have updated the code to return null if loading tenant id fails.

I think we should keep loadAzureTenantId() in DatabricksConfig for reusability. Looking at the Python SDK's implementation, load_azure_tenant_id() is called from multiple Azure credential providers (azure_service_principal and azure_cli), which demonstrates it's designed as a shared utility in the main Config class. This approach ensures consistency across our SDKs and allows for future reuse by other Azure authentication methods.

CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
CachedTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,91 @@ public void testGetTokenSourceWithOAuth() {
assertFalse(tokenSource instanceof ErrorTokenSource);
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
}

/** A non-3xx response (404) from {@code /aad/auth} must leave the tenant ID unset. */
@Test
public void testLoadAzureTenantId404() throws IOException {
  try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server URL is not an Azure hostname. Set a workspace resource ID, as the
    // happy-path test does, so that loadAzureTenantId() presumably gets past its isAzure()
    // guard and actually hits the 404 fixture instead of returning early.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    config.loadAzureTenantId();
    assertNull(config.getAzureTenantId());
  }
}

/** A 302 response without a {@code Location} header must leave the tenant ID unset. */
@Test
public void testLoadAzureTenantIdNoLocationHeader() throws IOException {
  try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server URL is not an Azure hostname. Set a workspace resource ID, as the
    // happy-path test does, so that loadAzureTenantId() presumably gets past its isAzure()
    // guard and actually inspects the header-less redirect instead of returning early.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    config.loadAzureTenantId();
    assertNull(config.getAzureTenantId());
  }
}

/** A redirect whose target has no tenant path segment must leave the tenant ID unset. */
@Test
public void testLoadAzureTenantIdUnparsableLocationHeader() throws IOException {
  FixtureServer.FixtureMapping fixture =
      new FixtureServer.FixtureMapping.Builder()
          .validateMethod("GET")
          .validatePath("/aad/auth")
          .withRedirect("https://unexpected-location", 302)
          .build();

  try (FixtureServer server = new FixtureServer().with(fixture)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server URL is not an Azure hostname. Set a workspace resource ID, as the
    // happy-path test does, so that loadAzureTenantId() presumably gets past its isAzure()
    // guard and actually parses the redirect target instead of returning early.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    config.loadAzureTenantId();
    assertNull(config.getAzureTenantId());
  }
}

/**
 * A 302 redirect from {@code /aad/auth} to a login URL containing {@code test-tenant-id} must
 * populate the tenant ID on the config.
 */
@Test
public void testLoadAzureTenantIdHappyPath() throws IOException {
  final String redirectTarget =
      "https://login.microsoftonline.com/test-tenant-id/oauth2/authorize";
  final String workspaceResourceId =
      "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws";

  FixtureServer.FixtureMapping.Builder mappingBuilder =
      new FixtureServer.FixtureMapping.Builder();
  FixtureServer.FixtureMapping redirectFixture =
      mappingBuilder
          .validateMethod("GET")
          .validatePath("/aad/auth")
          .withRedirect(redirectTarget, 302)
          .build();

  try (FixtureServer fixtureServer = new FixtureServer().with(redirectFixture)) {
    DatabricksConfig cfg = new DatabricksConfig();
    cfg.setHost(fixtureServer.getUrl());
    cfg.setAzureWorkspaceResourceId(workspaceResourceId);
    cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

    cfg.loadAzureTenantId();

    assertEquals("test-tenant-id", cfg.getAzureTenantId());
  }
}

/** With a non-Azure host, loading the tenant ID must leave it unset. */
@Test
public void testLoadAzureTenantIdSkipsWhenNotAzure() throws IOException {
  // non-azure host
  final String nonAzureHost = "https://my-workspace.cloud.databricks.com";

  DatabricksConfig cfg = new DatabricksConfig();
  cfg.setHost(nonAzureHost);
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  cfg.loadAzureTenantId();

  assertNull(cfg.getAzureTenantId());
}

/** A tenant ID that is already configured must not be overwritten. */
@Test
public void testLoadAzureTenantIdSkipsWhenAlreadySet() throws IOException {
  final String presetTenantId = "existing-tenant-id";

  DatabricksConfig cfg = new DatabricksConfig();
  cfg.setHost("https://adb-123.0.azuredatabricks.net");
  cfg.setAzureTenantId(presetTenantId);
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  cfg.loadAzureTenantId();

  assertEquals("existing-tenant-id", cfg.getAzureTenantId());
}

/** Without a host configured, loading the tenant ID must leave it unset. */
@Test
public void testLoadAzureTenantIdSkipsWhenNoHost() throws IOException {
  DatabricksConfig cfg = new DatabricksConfig();
  // Deliberately no setHost(...) call: only the HTTP client is configured.
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  cfg.loadAzureTenantId();

  assertNull(cfg.getAzureTenantId());
}
}
Loading