Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.60.0

### New Features and Improvements
- The Azure Service Principal credentials provider can now automatically discover the tenant ID when it is not explicitly provided.

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
* while automatically resolving different Azure environment endpoints.
*/
public class AzureServicePrincipalCredentialsProvider implements CredentialsProvider {
private static final Logger logger =
LoggerFactory.getLogger(AzureServicePrincipalCredentialsProvider.class);
private final ObjectMapper mapper = new ObjectMapper();
private String tenantId;

@Override
public String authType() {
Expand All @@ -22,12 +27,22 @@ public String authType() {
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()
|| config.getAzureClientId() == null
|| config.getAzureClientSecret() == null
|| config.getAzureTenantId() == null) {
|| config.getAzureClientSecret() == null) {
return null;
}
AzureUtils.ensureHostPresent(
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);

try {
this.tenantId =
config.getAzureTenantId() != null
? config.getAzureTenantId()
: AzureUtils.inferTenantId(config);
} catch (Exception e) {
logger.warn("Failed to infer Azure tenant ID: {}", e.getMessage());
return null;
}

AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);

CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
CachedTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
Expand Down Expand Up @@ -55,9 +70,9 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
* resource.
*/
private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
private CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
String tokenUrl = aadEndpoint + this.tenantId + "/oauth2/token";
Map<String, String> endpointParams = new HashMap<>();
endpointParams.put("resource", resource);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.net.URL;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

public class AzureUtils {

/** Path, relative to the workspace host, that is requested for tenant ID discovery. */
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";

public static String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
JsonNode properties = jsonResponse.get("properties");
if (properties == null) {
Expand Down Expand Up @@ -95,4 +99,74 @@ public static Optional<String> getAzureWorkspaceResourceId(Workspace workspace)
workspace.getWorkspaceName());
return Optional.of(resourceId);
}

/**
 * Resolves the Azure tenant ID for the given configuration.
 *
 * <p>A tenant ID that is already set on the config is returned unchanged. Otherwise the
 * workspace's {@code /aad/auth} URL is requested and the tenant ID is read from the redirect
 * target. The config itself is never modified.
 *
 * @param config The DatabricksConfig instance
 * @return the configured tenant ID, or the one discovered from the redirect
 * @throws DatabricksException if the host is missing, the workspace is not Azure, or tenant ID
 *     discovery fails
 */
public static String inferTenantId(DatabricksConfig config) throws DatabricksException {
  // Nothing to discover when the caller supplied a tenant ID explicitly.
  String explicitTenantId = config.getAzureTenantId();
  if (explicitTenantId != null) {
    return explicitTenantId;
  }

  String host = config.getHost();
  if (host == null) {
    throw new DatabricksException("Cannot infer tenant ID: host is missing");
  }
  if (!config.isAzure()) {
    throw new DatabricksException("Cannot infer tenant ID: workspace is not Azure");
  }

  String loginUrl = host + AZURE_AUTH_ENDPOINT;
  try {
    return extractTenantIdFromUrl(getRedirectLocation(config, loginUrl));
  } catch (Exception e) {
    // Any failure (HTTP, unexpected status, unparsable URL) is reported uniformly, keeping the
    // underlying exception as the cause.
    throw new DatabricksException("Failed to infer Azure tenant ID from " + loginUrl, e);
  }
}

/**
 * Sends a GET request to {@code loginUrl} with redirection behavior switched off and returns the
 * value of the {@code Location} header of the response.
 *
 * @param config the config whose HTTP client executes the request
 * @param loginUrl the URL to request
 * @return the value of the first {@code Location} header
 * @throws IOException as declared by the underlying HTTP call
 * @throws DatabricksException if the status code is not 302 or no {@code Location} header is
 *     present
 */
private static String getRedirectLocation(DatabricksConfig config, String loginUrl)
    throws IOException {
  Request probe = new Request("GET", loginUrl);
  // The redirect itself is the payload here, so it must be inspected rather than followed.
  probe.setRedirectionBehavior(false);
  Response reply = config.getHttpClient().execute(probe);

  int status = reply.getStatusCode();
  if (status != 302) {
    throw new DatabricksException(
        "Expected redirect (302) from " + loginUrl + ", got status code: " + status);
  }

  String target = reply.getFirstHeader("Location");
  if (target != null) {
    return target;
  }
  throw new DatabricksException("No Location header in redirect response from " + loginUrl);
}

/**
 * Extracts the tenant ID from a redirect URL whose first path segment is the tenant ID.
 *
 * @param redirectUrl the URL taken from the redirect response, expected to look like {@code
 *     https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...}
 * @return the first path segment of {@code redirectUrl}; never empty
 * @throws DatabricksException if the URL cannot be parsed or has no non-empty first path segment
 */
private static String extractTenantIdFromUrl(String redirectUrl) throws DatabricksException {
  try {
    // Parse: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
    URL entraIdUrl = new URL(redirectUrl);
    // For a path such as "/<tenant-id>/...", index 0 is the empty string before the leading
    // slash and the tenant ID sits at index 1.
    String[] pathSegments = entraIdUrl.getPath().split("/");

    // A path like "//oauth2/authorize" yields an empty segment at index 1; reject it instead of
    // handing back an empty tenant ID.
    if (pathSegments.length < 2 || pathSegments[1].isEmpty()) {
      throw new DatabricksException("Invalid path in Location header: " + entraIdUrl.getPath());
    }

    return pathSegments[1];
  } catch (Exception e) {
    // Deliberately broad: a malformed URL and the validation failure above are both reported
    // with the same message, carrying the original exception as the cause.
    throw new DatabricksException("Failed to parse tenant ID from URL " + redirectUrl, e);
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package com.databricks.sdk.core.utils;

import static org.junit.jupiter.api.Assertions.*;

import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.FixtureServer;
import com.databricks.sdk.core.commons.CommonsHttpClient;
import java.io.IOException;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AzureUtils#inferTenantId(DatabricksConfig)}, covering the failure modes of the
 * {@code /aad/auth} redirect probe, the happy path, and the precondition checks.
 */
public class AzureUtilsTest {

  private static final String AUTH_PATH = "/aad/auth";
  private static final String WORKSPACE_RESOURCE_ID =
      "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws";

  /** Installs the HTTP client used by every test on the given config. */
  private static void attachHttpClient(DatabricksConfig cfg) {
    cfg.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
  }

  /** Builds a config for {@code host} with a workspace resource ID and an HTTP client. */
  private static DatabricksConfig azureConfigFor(String host) {
    DatabricksConfig cfg = new DatabricksConfig();
    cfg.setHost(host);
    cfg.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);
    attachHttpClient(cfg);
    return cfg;
  }

  /** Builds a fixture answering {@code GET /aad/auth} with a 302 redirect to {@code location}. */
  private static FixtureServer.FixtureMapping redirectTo(String location) {
    return new FixtureServer.FixtureMapping.Builder()
        .validateMethod("GET")
        .validatePath(AUTH_PATH)
        .withRedirect(location, 302)
        .build();
  }

  /** Runs tenant ID inference and returns the {@link DatabricksException} it must throw. */
  private static DatabricksException inferExpectingFailure(DatabricksConfig cfg) {
    return assertThrows(DatabricksException.class, () -> AzureUtils.inferTenantId(cfg));
  }

  /**
   * Asserts that inference fails with the generic wrapper message for {@code loginUrl}, that the
   * cause is a {@link DatabricksException} with {@code causeMessage}, and that the config's
   * tenant ID is still unset.
   */
  private static void assertWrappedFailure(
      DatabricksConfig cfg, String loginUrl, String causeMessage) {
    DatabricksException thrown = inferExpectingFailure(cfg);
    assertEquals("Failed to infer Azure tenant ID from " + loginUrl, thrown.getMessage());

    Throwable cause = thrown.getCause();
    assertNotNull(cause);
    assertInstanceOf(DatabricksException.class, cause);
    assertEquals(causeMessage, cause.getMessage());

    assertNull(cfg.getAzureTenantId());
  }

  @Test
  public void testInferTenantId404() throws IOException {
    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
      String loginUrl = server.getUrl() + AUTH_PATH;
      assertWrappedFailure(
          azureConfigFor(server.getUrl()),
          loginUrl,
          "Expected redirect (302) from " + loginUrl + ", got status code: 404");
    }
  }

  @Test
  public void testInferTenantIdNoLocationHeader() throws IOException {
    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
      String loginUrl = server.getUrl() + AUTH_PATH;
      assertWrappedFailure(
          azureConfigFor(server.getUrl()),
          loginUrl,
          "No Location header in redirect response from " + loginUrl);
    }
  }

  @Test
  public void testInferTenantIdUnparsableLocationHeader() throws IOException {
    FixtureServer.FixtureMapping fixture = redirectTo("https://unexpected-location");

    try (FixtureServer server = new FixtureServer().with(fixture)) {
      assertWrappedFailure(
          azureConfigFor(server.getUrl()),
          server.getUrl() + AUTH_PATH,
          "Failed to parse tenant ID from URL https://unexpected-location");
    }
  }

  @Test
  public void testInferTenantIdHappyPath() throws IOException {
    FixtureServer.FixtureMapping fixture =
        redirectTo("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize");

    try (FixtureServer server = new FixtureServer().with(fixture)) {
      DatabricksConfig cfg = azureConfigFor(server.getUrl());
      String inferred = AzureUtils.inferTenantId(cfg);
      assertEquals("test-tenant-id", inferred);
      // Inference only returns the value; it must not write it back to the config.
      assertNull(cfg.getAzureTenantId());
    }
  }

  @Test
  public void testInferTenantIdSkipsWhenNotAzure() {
    DatabricksConfig cfg = new DatabricksConfig();
    cfg.setHost("https://my-workspace.cloud.databricks.com"); // non-azure host
    attachHttpClient(cfg);

    DatabricksException thrown = inferExpectingFailure(cfg);
    assertEquals("Cannot infer tenant ID: workspace is not Azure", thrown.getMessage());
    assertNull(cfg.getAzureTenantId());
  }

  @Test
  public void testInferTenantIdSkipsWhenAlreadySet() {
    DatabricksConfig cfg = new DatabricksConfig();
    cfg.setHost("https://adb-123.0.azuredatabricks.net");
    cfg.setAzureTenantId("existing-tenant-id");
    attachHttpClient(cfg);

    String inferred = AzureUtils.inferTenantId(cfg);
    assertEquals("existing-tenant-id", inferred);
    // The pre-set value must be left untouched.
    assertEquals("existing-tenant-id", cfg.getAzureTenantId());
  }

  @Test
  public void testInferTenantIdSkipsWhenNoHost() {
    DatabricksConfig cfg = new DatabricksConfig();
    attachHttpClient(cfg);

    DatabricksException thrown = inferExpectingFailure(cfg);
    assertEquals("Cannot infer tenant ID: host is missing", thrown.getMessage());
    assertNull(cfg.getAzureTenantId());
  }
}
Loading