Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.57.0

### New Features and Improvements
- Azure Service Principal credential provider can now automatically discover tenant ID when not explicitly provided

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,19 @@
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.*;
import org.apache.http.HttpMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DatabricksConfig {

private static final Logger logger = LoggerFactory.getLogger(DatabricksConfig.class);

/** Azure authentication endpoint for tenant ID discovery */
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";

private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();

@ConfigAttribute(env = "DATABRICKS_HOST")
Expand Down Expand Up @@ -726,7 +735,7 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
}

public DatabricksConfig clone() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to clone this method for all our config classes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate on this comment? Are you asking whether we need to implement similar clone() methods in other configuration classes, or are you suggesting a different architectural approach for handling object cloning?

return clone(new HashSet<>());
return clone(new HashSet<>(Arrays.asList("logger", "AZURE_AUTH_ENDPOINT")));
}

public DatabricksConfig newWithWorkspaceHost(String host) {
Expand All @@ -736,6 +745,8 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
// The config for WorkspaceClient has a different host and Azure Workspace resource
// ID, and also omits
// the account ID.
"logger",
"AZURE_AUTH_ENDPOINT",
"host",
"accountId",
"azureWorkspaceResourceId",
Expand All @@ -755,4 +766,88 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
public String getEffectiveOAuthRedirectUrl() {
return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback";
}

/**
* [Internal] Load the Azure tenant ID from the Azure Databricks login page. If the tenant ID is
* already set, this method does nothing.
*
* @return true if tenant ID is available (either was already set or successfully loaded), false otherwise
*/
public boolean loadAzureTenantId() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that there are a lot of exceptions in the current code, but I would recommend treating the config as an immutable object. That is, do not change the tenantId in the config; instead, keep a mutable copy in the AzureServicePrincipalCredentialsProvider.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I have updated the code to not make any changes to the state of DatabricksConfig


if (azureTenantId != null) {
return true; // Tenant ID already available - success
}

if (!isAzure() || host == null) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the isAzure check be the first check in this method, returning false even if azureTenantId already has a value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this method is to load the tenant ID, so if it's already set (explicitly by the user), we should respect that and return true. The isAzure() check is only needed when we need to discover the tenant ID from the workspace host. If the tenant ID is set and the host is not Azure, it should get caught at appropriate layers - this method should not be responsible for returning false and blocking the flow.

return false; // Configuration issue - can't perform operation
}

String loginUrl = host + AZURE_AUTH_ENDPOINT;
logger.debug("Loading tenant ID from {}", loginUrl);

try {
String redirectLocation = getRedirectLocation(loginUrl);
if (redirectLocation == null) {
return false; // Failed to get redirect location
}

String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
if (extractedTenantId == null) {
return false; // Failed to extract tenant ID
}

this.azureTenantId = extractedTenantId;
logger.debug("Loaded tenant ID: {}", this.azureTenantId);
return true; // Successfully loaded

} catch (Exception e) {
logger.warn("Failed to load tenant ID: {}", e.getMessage());
return false;
}
}

/**
 * Issues a GET to {@code loginUrl} and returns the value of the {@code Location} header of the
 * response.
 *
 * @param loginUrl URL to request
 * @return the {@code Location} header value, or {@code null} when the response status is not 302
 *     or the header is absent (a warning is logged in both cases)
 * @throws IOException declared for the underlying HTTP call
 */
private String getRedirectLocation(String loginUrl) throws IOException {
  Request probe = new Request("GET", loginUrl);
  // NOTE(review): presumably this stops the client from following the redirect so the
  // 302 response itself can be inspected - confirm against Request#setRedirectionBehavior.
  probe.setRedirectionBehavior(false);

  Response reply = getHttpClient().execute(probe);
  int status = reply.getStatusCode();
  if (status != 302) {
    logger.warn(
        "Failed to get tenant ID from {}: expected status code 302, got {}", loginUrl, status);
    return null;
  }

  String target = reply.getFirstHeader("Location");
  if (target != null) {
    return target;
  }

  logger.warn("No Location header in response from {}", loginUrl);
  return null;
}

/**
 * Extracts the tenant ID from a redirect URL of the form {@code
 * https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...}. The domain may differ
 * depending on the Azure cloud (e.g. {@code login.microsoftonline.us} for US Government cloud);
 * only the first path segment is used.
 *
 * @param redirectUrl value of the {@code Location} header
 * @return the first path segment of the URL, or {@code null} when the URL cannot be parsed or its
 *     first path segment is missing or empty (a warning is logged)
 */
private String extractTenantIdFromUrl(String redirectUrl) {
  try {
    URL entraIdUrl = new URL(redirectUrl);
    // The path starts with "/", so index 0 of the split is the empty string before it and the
    // tenant ID, when present, is at index 1.
    String[] pathSegments = entraIdUrl.getPath().split("/");

    // An empty segment (e.g. a path of "//oauth2/authorize") is not a usable tenant ID; treat it
    // the same as a missing one instead of reporting "" as a successful result.
    if (pathSegments.length < 2 || pathSegments[1].isEmpty()) {
      logger.warn("Invalid path in Location header: {}", entraIdUrl.getPath());
      return null;
    }

    return pathSegments[1];
  } catch (Exception e) {
    logger.warn("Failed to extract tenant ID from URL {}: {}", redirectUrl, e.getMessage());
    return null;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@ public String authType() {
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
|| config.getAzureClientSecret() == null
|| config.getAzureTenantId() == null) {
|| config.getAzureClientSecret() == null) {
return null;
}
AzureUtils.ensureHostPresent(
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);

boolean tenantIdLoaded = config.loadAzureTenantId();
if (!tenantIdLoaded) {
return null;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in another comment, I would recommend implementing the loadAzureTenantId as an immutable function that would return the tenantId or fail. The tenantId can be stored in this class.

Something like this:

try {
  this.tenantId = inferTenantId(config)
} catch (Exception e) {
  logger.debug("Failed to extract tenant ID: {}", e.getMessage())
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the code as per your suggestions:

  • Not changing the state of DatabricksConfig
  • Maintaining local variable for tenantId in AzureServicePrincipalCredentialsProvider
  • Moved the tenant id inferring logic to AzureUtils for reusability


CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
CachedTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,98 @@ public void testGetTokenSourceWithOAuth() {
assertFalse(tokenSource instanceof ErrorTokenSource);
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
}

/** A non-302 (404) response from /aad/auth must leave the tenant ID unset and report failure. */
@Test
public void testLoadAzureTenantId404() throws IOException {
  try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server's URL is not an Azure hostname, so mark the config as Azure through the
    // workspace resource ID (as the happy-path test does). Without it, loadAzureTenantId()
    // presumably returns at its isAzure() guard and the 404 response is never requested.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    boolean result = config.loadAzureTenantId();
    assertFalse(result);
    assertNull(config.getAzureTenantId());
  }
}

/** A 302 response without a Location header must leave the tenant ID unset and report failure. */
@Test
public void testLoadAzureTenantIdNoLocationHeader() throws IOException {
  try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server's URL is not an Azure hostname, so mark the config as Azure through the
    // workspace resource ID (as the happy-path test does). Without it, loadAzureTenantId()
    // presumably returns at its isAzure() guard and the header-less 302 is never requested.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    boolean result = config.loadAzureTenantId();
    assertFalse(result);
    assertNull(config.getAzureTenantId());
  }
}

/**
 * A redirect whose Location has no tenant path segment must leave the tenant ID unset and report
 * failure.
 */
@Test
public void testLoadAzureTenantIdUnparsableLocationHeader() throws IOException {
  FixtureServer.FixtureMapping fixture =
      new FixtureServer.FixtureMapping.Builder()
          .validateMethod("GET")
          .validatePath("/aad/auth")
          .withRedirect("https://unexpected-location", 302)
          .build();

  try (FixtureServer server = new FixtureServer().with(fixture)) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(server.getUrl());
    // The fixture server's URL is not an Azure hostname, so mark the config as Azure through the
    // workspace resource ID (as the happy-path test does). Without it, loadAzureTenantId()
    // presumably returns at its isAzure() guard and the redirect is never parsed.
    config.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    boolean result = config.loadAzureTenantId();
    assertFalse(result);
    assertNull(config.getAzureTenantId());
  }
}

/**
 * A 302 redirect to an Entra ID authorize URL yields the first path segment as the tenant ID and
 * a {@code true} result.
 */
@Test
public void testLoadAzureTenantIdHappyPath() throws IOException {
  FixtureServer.FixtureMapping redirectToEntra =
      new FixtureServer.FixtureMapping.Builder()
          .validateMethod("GET")
          .validatePath("/aad/auth")
          .withRedirect("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize", 302)
          .build();

  try (FixtureServer workspace = new FixtureServer().with(redirectToEntra)) {
    DatabricksConfig cfg = new DatabricksConfig();
    cfg.setHost(workspace.getUrl());
    cfg.setAzureWorkspaceResourceId(
        "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
    cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

    assertTrue(cfg.loadAzureTenantId());
    assertEquals("test-tenant-id", cfg.getAzureTenantId());
  }
}

/** With a non-Azure host and no tenant ID set, loading reports failure and sets nothing. */
@Test
public void testLoadAzureTenantIdSkipsWhenNotAzure() throws IOException {
  DatabricksConfig cfg = new DatabricksConfig();
  // Non-Azure workspace host.
  cfg.setHost("https://my-workspace.cloud.databricks.com");
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  assertFalse(cfg.loadAzureTenantId());
  assertNull(cfg.getAzureTenantId());
}

/** An explicitly configured tenant ID is kept as-is and loading reports success. */
@Test
public void testLoadAzureTenantIdSkipsWhenAlreadySet() throws IOException {
  DatabricksConfig cfg = new DatabricksConfig();
  cfg.setHost("https://adb-123.0.azuredatabricks.net");
  cfg.setAzureTenantId("existing-tenant-id");
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  assertTrue(cfg.loadAzureTenantId());
  assertEquals("existing-tenant-id", cfg.getAzureTenantId());
}

/** With no host configured, loading reports failure and the tenant ID stays unset. */
@Test
public void testLoadAzureTenantIdSkipsWhenNoHost() throws IOException {
  DatabricksConfig cfg = new DatabricksConfig();
  cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

  assertFalse(cfg.loadAzureTenantId());
  assertNull(cfg.getAzureTenantId());
}
}
Loading