Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.60.0

### New Features and Improvements
- Azure Service Principal credential provider can now automatically discover the tenant ID when it is not explicitly provided.

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
* while automatically resolving different Azure environment endpoints.
*/
public class AzureServicePrincipalCredentialsProvider implements CredentialsProvider {
private static final Logger logger =
LoggerFactory.getLogger(AzureServicePrincipalCredentialsProvider.class);
private final ObjectMapper mapper = new ObjectMapper();
private String tenantId;

@Override
public String authType() {
Expand All @@ -22,12 +27,22 @@ public String authType() {
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
|| config.getAzureClientSecret() == null
|| config.getAzureTenantId() == null) {
|| config.getAzureClientSecret() == null) {
return null;
}
AzureUtils.ensureHostPresent(
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);

try {
this.tenantId =
config.getAzureTenantId() != null
? config.getAzureTenantId()
: AzureUtils.inferTenantId(config);
} catch (Exception e) {
logger.warn("Failed to infer Azure tenant ID: {}", e.getMessage());
return null;
}

AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);

CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
CachedTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
Expand Down Expand Up @@ -55,9 +70,9 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
* resource.
*/
private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
private CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
String tokenUrl = aadEndpoint + this.tenantId + "/oauth2/token";
Map<String, String> endpointParams = new HashMap<>();
endpointParams.put("resource", resource);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.net.URL;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AzureUtils {

private static final Logger logger = LoggerFactory.getLogger(AzureUtils.class);

/** Azure authentication endpoint for tenant ID discovery */
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";

public static String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
JsonNode properties = jsonResponse.get("properties");
if (properties == null) {
Expand Down Expand Up @@ -95,4 +103,76 @@ public static Optional<String> getAzureWorkspaceResourceId(Workspace workspace)
workspace.getWorkspaceName());
return Optional.of(resourceId);
}

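/**
* Infers the Azure tenant ID for a workspace. If the config already carries a tenant ID, it is
* returned as-is; otherwise the workspace's {@code /aad/auth} endpoint is requested and the
* tenant ID is taken from the first path segment of the redirect's {@code Location} header.
*
* @param config The DatabricksConfig instance; it must have a host set and be an Azure workspace
* @return the configured or discovered tenant ID
* @throws DatabricksException if the host is missing, the workspace is not Azure, or tenant ID
* discovery fails
*/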
public static String inferTenantId(DatabricksConfig config) throws DatabricksException {

if (config.getAzureTenantId() != null) {
return config.getAzureTenantId();
}

if (config.getHost() == null) {
throw new DatabricksException("Cannot infer tenant ID: host is missing");
}

if (!config.isAzure()) {
throw new DatabricksException("Cannot infer tenant ID: workspace is not Azure");
}

String loginUrl = config.getHost() + AZURE_AUTH_ENDPOINT;

try {
String redirectLocation = getRedirectLocation(config, loginUrl);
String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
logger.info("Successfully discovered Azure tenant ID: {}", extractedTenantId);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also remove `private static final Logger logger = LoggerFactory.getLogger(AzureUtils.class);`

Suggested change
logger.info("Successfully discovered Azure tenant ID: {}", extractedTenantId);

return extractedTenantId;

} catch (Exception e) {
throw new DatabricksException("Failed to infer Azure tenant ID from " + loginUrl, e);
}
}

/**
 * Sends a GET request to {@code loginUrl} and returns the redirect target the server replies with.
 *
 * <p>The request is issued with {@code setRedirectionBehavior(false)} (presumably so that the HTTP
 * client does not follow the redirect itself; confirm against {@code Request}).
 *
 * @param config configuration whose HTTP client executes the request
 * @param loginUrl URL that is requested
 * @return the value of the first {@code Location} header of the response
 * @throws IOException declared for failures while executing the request
 * @throws DatabricksException if the response status is not 302, or if the response carries no
 *     {@code Location} header
 */
private static String getRedirectLocation(DatabricksConfig config, String loginUrl)
    throws IOException {
  Request probe = new Request("GET", loginUrl);
  probe.setRedirectionBehavior(false);
  Response reply = config.getHttpClient().execute(probe);

  int status = reply.getStatusCode();
  if (status != 302) {
    String message = "Expected redirect (302) from " + loginUrl + ", got status code: " + status;
    throw new DatabricksException(message);
  }

  String target = reply.getFirstHeader("Location");
  if (target != null) {
    return target;
  }
  throw new DatabricksException("No Location header in redirect response from " + loginUrl);
}

/**
 * Extracts the tenant ID from a redirect URL by taking the first segment of its path.
 *
 * <p>Every failure inside the method, including the invalid-path check, is re-wrapped by the
 * trailing {@code catch}, so callers always see the "Failed to parse tenant ID from URL" message
 * with the underlying problem attached as the cause.
 *
 * @param redirectUrl the URL taken from the redirect's {@code Location} header
 * @return the first path segment of {@code redirectUrl}; never empty
 * @throws DatabricksException if the URL cannot be parsed, has no path segment, or its first path
 *     segment is empty
 */
private static String extractTenantIdFromUrl(String redirectUrl) throws DatabricksException {
  try {
    // Parse: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
    URL entraIdUrl = new URL(redirectUrl);
    // The path starts with "/", so index 0 is the empty string and index 1 is the tenant ID.
    String[] pathSegments = entraIdUrl.getPath().split("/");

    // A path such as "//oauth2/authorize" has enough segments but an empty one at index 1;
    // reject it rather than handing back "" as a tenant ID.
    if (pathSegments.length < 2 || pathSegments[1].isEmpty()) {
      throw new DatabricksException("Invalid path in Location header: " + entraIdUrl.getPath());
    }

    return pathSegments[1];
  } catch (Exception e) {
    throw new DatabricksException("Failed to parse tenant ID from URL " + redirectUrl, e);
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package com.databricks.sdk.core.utils;

import static org.junit.jupiter.api.Assertions.*;

import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.FixtureServer;
import com.databricks.sdk.core.commons.CommonsHttpClient;
import java.io.IOException;
import org.junit.jupiter.api.Test;

/** Tests for {@link AzureUtils#inferTenantId(DatabricksConfig)}. */
public class AzureUtilsTest {

  /** Path the fixture server is expected to be queried on. */
  private static final String AUTH_PATH = "/aad/auth";

  /** Workspace resource ID set on configs that talk to the fixture server. */
  private static final String WORKSPACE_RESOURCE_ID =
      "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws";

  /** Creates a config for {@code host} carrying the workspace resource ID and an HTTP client. */
  private static DatabricksConfig azureConfigFor(String host) {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost(host);
    config.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
    return config;
  }

  /** Creates a fixture answering {@code GET /aad/auth} with a 302 redirect to {@code location}. */
  private static FixtureServer.FixtureMapping redirectTo(String location) {
    return new FixtureServer.FixtureMapping.Builder()
        .validateMethod("GET")
        .validatePath(AUTH_PATH)
        .withRedirect(location, 302)
        .build();
  }

  /** Asserts that inference throws a DatabricksException with exactly {@code expectedMessage}. */
  private static DatabricksException assertInferenceFails(
      DatabricksConfig config, String expectedMessage) {
    DatabricksException thrown =
        assertThrows(DatabricksException.class, () -> AzureUtils.inferTenantId(config));
    assertEquals(expectedMessage, thrown.getMessage());
    return thrown;
  }

  /**
   * Asserts that inference against {@code baseUrl} fails with the generic outer message, that the
   * cause is a DatabricksException with {@code expectedCauseMessage}, and that the config's tenant
   * ID is still unset afterwards.
   */
  private static void assertDiscoveryFails(
      DatabricksConfig config, String baseUrl, String expectedCauseMessage) {
    DatabricksException thrown =
        assertInferenceFails(config, "Failed to infer Azure tenant ID from " + baseUrl + AUTH_PATH);

    Throwable cause = thrown.getCause();
    assertNotNull(cause);
    assertInstanceOf(DatabricksException.class, cause);
    assertEquals(expectedCauseMessage, ((DatabricksException) cause).getMessage());

    assertNull(config.getAzureTenantId());
  }

  @Test
  public void testInferTenantId404() throws IOException {
    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
      String baseUrl = server.getUrl();
      assertDiscoveryFails(
          azureConfigFor(baseUrl),
          baseUrl,
          "Expected redirect (302) from " + baseUrl + AUTH_PATH + ", got status code: 404");
    }
  }

  @Test
  public void testInferTenantIdNoLocationHeader() throws IOException {
    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
      String baseUrl = server.getUrl();
      assertDiscoveryFails(
          azureConfigFor(baseUrl),
          baseUrl,
          "No Location header in redirect response from " + baseUrl + AUTH_PATH);
    }
  }

  @Test
  public void testInferTenantIdUnparsableLocationHeader() throws IOException {
    FixtureServer.FixtureMapping fixture = redirectTo("https://unexpected-location");

    try (FixtureServer server = new FixtureServer().with(fixture)) {
      String baseUrl = server.getUrl();
      assertDiscoveryFails(
          azureConfigFor(baseUrl),
          baseUrl,
          "Failed to parse tenant ID from URL https://unexpected-location");
    }
  }

  @Test
  public void testInferTenantIdHappyPath() throws IOException {
    FixtureServer.FixtureMapping fixture =
        redirectTo("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize");

    try (FixtureServer server = new FixtureServer().with(fixture)) {
      DatabricksConfig config = azureConfigFor(server.getUrl());

      assertEquals("test-tenant-id", AzureUtils.inferTenantId(config));
      // Inference only returns the value; it must not write it back into the config.
      assertNull(config.getAzureTenantId());
    }
  }

  @Test
  public void testInferTenantIdSkipsWhenNotAzure() {
    DatabricksConfig config = new DatabricksConfig();
    // A host that is not an Azure Databricks host.
    config.setHost("https://my-workspace.cloud.databricks.com");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

    assertInferenceFails(config, "Cannot infer tenant ID: workspace is not Azure");
    assertNull(config.getAzureTenantId());
  }

  @Test
  public void testInferTenantIdSkipsWhenAlreadySet() {
    DatabricksConfig config = new DatabricksConfig();
    config.setHost("https://adb-123.0.azuredatabricks.net");
    config.setAzureTenantId("existing-tenant-id");
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

    assertEquals("existing-tenant-id", AzureUtils.inferTenantId(config));
    // The pre-configured value must be left untouched.
    assertEquals("existing-tenant-id", config.getAzureTenantId());
  }

  @Test
  public void testInferTenantIdSkipsWhenNoHost() {
    DatabricksConfig config = new DatabricksConfig();
    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());

    assertInferenceFails(config, "Cannot infer tenant ID: host is missing");
    assertNull(config.getAzureTenantId());
  }
}
Loading