Skip to content

Commit fe28261

Browse files
Infer azure tenant ID. (#482)
## What changes are proposed in this pull request? This PR modifies Azure service principal credential provider to attempt to load the tenant ID of the workspace if not provided before authenticating. Tenant ID is indirectly exposed via the redirect URL used when logging into a workspace. In this PR, we fetch the tenant ID from this endpoint and configure it if not already set. Reference PR: databricks/databricks-sdk-py#638 **Key changes:** - **Added `inferTenantId()` method** in `AzureUtils` that makes an HTTP request to `{host}/aad/auth` and extracts the tenant ID from the redirect URL - **Modified `AzureServicePrincipalCredentialsProvider`** to remove the explicit `azureTenantId` requirement and automatically call tenant ID discovery - **Added comprehensive unit tests** covering success scenarios, error handling, and edge cases **Technical implementation:** - Makes HTTP GET request to `https://<workspace-host>/aad/auth` endpoint - Follows redirect chain to extract tenant ID from URLs like `https://login.microsoftonline.com/{tenant-id}/oauth2/authorize` - Handles various error scenarios gracefully (404, missing headers, malformed URLs) - Only works for Azure workspaces (skips for non-Azure hosts) - Respects existing tenant ID if already configured ## Why are these changes needed? Currently, Azure Databricks users must manually specify the tenant-id when using Service Principal authentication. With this feature, users don't need to manually specify tenant ID, thus improving the user experience. ## How is this tested? ### Unit Tests **Comprehensive unit tests** in `AzureUtilsTest` covering: - ✅ Happy path (successful tenant ID extraction) - ✅ Error scenarios (404, missing Location header, malformed URLs) - ✅ Edge cases (non-Azure workspaces, already set tenant ID, missing host) - ✅ Integration with existing functionality ### Manual Testing - **✅ Locally built SDK JAR imported into JDBC driver** and tested connection without passing `azure_tenant_id` - **✅ Verified successful authentication** using auto-discovered tenant ID - **✅ Confirmed backward compatibility** - still works when explicit tenant ID is provided --------- Signed-off-by: Sreekanth Vadigi <[email protected]> Co-authored-by: Renaud Hartert <[email protected]>
1 parent 6cd3cf6 commit fe28261

File tree

4 files changed

+268
-6
lines changed

4 files changed

+268
-6
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Release v0.60.0
44

55
### New Features and Improvements
6+
- Azure Service Principal credential provider can now automatically discover tenant ID when not explicitly provided
67

78
### Bug Fixes
89

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import java.util.HashMap;
77
import java.util.Map;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
810

911
/**
1012
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
1113
* while automatically resolving different Azure environment endpoints.
1214
*/
1315
public class AzureServicePrincipalCredentialsProvider implements CredentialsProvider {
16+
private static final Logger logger =
17+
LoggerFactory.getLogger(AzureServicePrincipalCredentialsProvider.class);
1418
private final ObjectMapper mapper = new ObjectMapper();
19+
private String tenantId;
1520

1621
@Override
1722
public String authType() {
@@ -22,12 +27,22 @@ public String authType() {
2227
public OAuthHeaderFactory configure(DatabricksConfig config) {
2328
if (!config.isAzure()
2429
|| config.getAzureClientId() == null
25-
|| config.getAzureClientSecret() == null
26-
|| config.getAzureTenantId() == null) {
30+
|| config.getAzureClientSecret() == null) {
2731
return null;
2832
}
29-
AzureUtils.ensureHostPresent(
30-
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);
33+
34+
try {
35+
this.tenantId =
36+
config.getAzureTenantId() != null
37+
? config.getAzureTenantId()
38+
: AzureUtils.inferTenantId(config);
39+
} catch (Exception e) {
40+
logger.warn("Failed to infer Azure tenant ID: {}", e.getMessage());
41+
return null;
42+
}
43+
44+
AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);
45+
3146
CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
3247
CachedTokenSource cloud =
3348
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
@@ -55,9 +70,9 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
5570
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
5671
* resource.
5772
*/
58-
private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
73+
private CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
5974
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
60-
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
75+
String tokenUrl = aadEndpoint + this.tenantId + "/oauth2/token";
6176
Map<String, String> endpointParams = new HashMap<>();
6277
endpointParams.put("resource", resource);
6378

databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import com.fasterxml.jackson.databind.ObjectMapper;
1111
import com.fasterxml.jackson.databind.node.ObjectNode;
1212
import java.io.IOException;
13+
import java.net.URL;
1314
import java.util.Map;
1415
import java.util.Optional;
1516
import java.util.function.BiFunction;
1617

1718
public class AzureUtils {
1819

20+
/** Azure authentication endpoint for tenant ID discovery */
21+
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";
22+
1923
public static String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
2024
JsonNode properties = jsonResponse.get("properties");
2125
if (properties == null) {
@@ -95,4 +99,74 @@ public static Optional<String> getAzureWorkspaceResourceId(Workspace workspace)
9599
workspace.getWorkspaceName());
96100
return Optional.of(resourceId);
97101
}
102+
103+
/**
104+
* Infers the Azure tenant ID from the Databricks workspace login page.
105+
*
106+
* @param config The DatabricksConfig instance
107+
* @return the discovered tenant ID
108+
* @throws DatabricksException if tenant ID discovery fails
109+
*/
110+
public static String inferTenantId(DatabricksConfig config) throws DatabricksException {
111+
112+
if (config.getAzureTenantId() != null) {
113+
return config.getAzureTenantId();
114+
}
115+
116+
if (config.getHost() == null) {
117+
throw new DatabricksException("Cannot infer tenant ID: host is missing");
118+
}
119+
120+
if (!config.isAzure()) {
121+
throw new DatabricksException("Cannot infer tenant ID: workspace is not Azure");
122+
}
123+
124+
String loginUrl = config.getHost() + AZURE_AUTH_ENDPOINT;
125+
126+
try {
127+
String redirectLocation = getRedirectLocation(config, loginUrl);
128+
return extractTenantIdFromUrl(redirectLocation);
129+
130+
} catch (Exception e) {
131+
throw new DatabricksException("Failed to infer Azure tenant ID from " + loginUrl, e);
132+
}
133+
}
134+
135+
private static String getRedirectLocation(DatabricksConfig config, String loginUrl)
136+
throws IOException {
137+
Request request = new Request("GET", loginUrl);
138+
request.setRedirectionBehavior(false);
139+
Response response = config.getHttpClient().execute(request);
140+
141+
if (response.getStatusCode() != 302) {
142+
throw new DatabricksException(
143+
"Expected redirect (302) from "
144+
+ loginUrl
145+
+ ", got status code: "
146+
+ response.getStatusCode());
147+
}
148+
149+
String location = response.getFirstHeader("Location");
150+
if (location == null) {
151+
throw new DatabricksException("No Location header in redirect response from " + loginUrl);
152+
}
153+
154+
return location;
155+
}
156+
157+
private static String extractTenantIdFromUrl(String redirectUrl) throws DatabricksException {
158+
try {
159+
// Parse: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
160+
URL entraIdUrl = new URL(redirectUrl);
161+
String[] pathSegments = entraIdUrl.getPath().split("/");
162+
163+
if (pathSegments.length < 2) {
164+
throw new DatabricksException("Invalid path in Location header: " + entraIdUrl.getPath());
165+
}
166+
167+
return pathSegments[1];
168+
} catch (Exception e) {
169+
throw new DatabricksException("Failed to parse tenant ID from URL " + redirectUrl, e);
170+
}
171+
}
98172
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package com.databricks.sdk.core.utils;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import com.databricks.sdk.core.DatabricksConfig;
6+
import com.databricks.sdk.core.DatabricksException;
7+
import com.databricks.sdk.core.FixtureServer;
8+
import com.databricks.sdk.core.commons.CommonsHttpClient;
9+
import java.io.IOException;
10+
import org.junit.jupiter.api.Test;
11+
12+
public class AzureUtilsTest {
13+
14+
@Test
15+
public void testInferTenantId404() throws IOException {
16+
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
17+
DatabricksConfig config = new DatabricksConfig();
18+
config.setHost(server.getUrl());
19+
config.setAzureWorkspaceResourceId(
20+
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
21+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
22+
23+
DatabricksException exception =
24+
assertThrows(
25+
DatabricksException.class,
26+
() -> {
27+
AzureUtils.inferTenantId(config);
28+
});
29+
assertEquals(
30+
"Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
31+
exception.getMessage());
32+
33+
assertNotNull(exception.getCause());
34+
assertInstanceOf(DatabricksException.class, exception.getCause());
35+
DatabricksException cause = (DatabricksException) exception.getCause();
36+
assertEquals(
37+
"Expected redirect (302) from " + server.getUrl() + "/aad/auth, got status code: 404",
38+
cause.getMessage());
39+
40+
assertNull(config.getAzureTenantId());
41+
}
42+
}
43+
44+
@Test
45+
public void testInferTenantIdNoLocationHeader() throws IOException {
46+
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
47+
DatabricksConfig config = new DatabricksConfig();
48+
config.setHost(server.getUrl());
49+
config.setAzureWorkspaceResourceId(
50+
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
51+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
52+
53+
DatabricksException exception =
54+
assertThrows(
55+
DatabricksException.class,
56+
() -> {
57+
AzureUtils.inferTenantId(config);
58+
});
59+
assertEquals(
60+
"Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
61+
exception.getMessage());
62+
63+
assertNotNull(exception.getCause());
64+
assertInstanceOf(DatabricksException.class, exception.getCause());
65+
DatabricksException cause = (DatabricksException) exception.getCause();
66+
assertEquals(
67+
"No Location header in redirect response from " + server.getUrl() + "/aad/auth",
68+
cause.getMessage());
69+
70+
assertNull(config.getAzureTenantId());
71+
}
72+
}
73+
74+
@Test
75+
public void testInferTenantIdUnparsableLocationHeader() throws IOException {
76+
FixtureServer.FixtureMapping fixture =
77+
new FixtureServer.FixtureMapping.Builder()
78+
.validateMethod("GET")
79+
.validatePath("/aad/auth")
80+
.withRedirect("https://unexpected-location", 302)
81+
.build();
82+
83+
try (FixtureServer server = new FixtureServer().with(fixture)) {
84+
DatabricksConfig config = new DatabricksConfig();
85+
config.setHost(server.getUrl());
86+
config.setAzureWorkspaceResourceId(
87+
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
88+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
89+
90+
DatabricksException exception =
91+
assertThrows(
92+
DatabricksException.class,
93+
() -> {
94+
AzureUtils.inferTenantId(config);
95+
});
96+
assertEquals(
97+
"Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
98+
exception.getMessage());
99+
100+
assertNotNull(exception.getCause());
101+
assertInstanceOf(DatabricksException.class, exception.getCause());
102+
DatabricksException cause = (DatabricksException) exception.getCause();
103+
assertEquals(
104+
"Failed to parse tenant ID from URL https://unexpected-location", cause.getMessage());
105+
106+
assertNull(config.getAzureTenantId());
107+
}
108+
}
109+
110+
@Test
111+
public void testInferTenantIdHappyPath() throws IOException {
112+
FixtureServer.FixtureMapping fixture =
113+
new FixtureServer.FixtureMapping.Builder()
114+
.validateMethod("GET")
115+
.validatePath("/aad/auth")
116+
.withRedirect("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize", 302)
117+
.build();
118+
119+
try (FixtureServer server = new FixtureServer().with(fixture)) {
120+
DatabricksConfig config = new DatabricksConfig();
121+
config.setHost(server.getUrl());
122+
config.setAzureWorkspaceResourceId(
123+
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
124+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
125+
String result = AzureUtils.inferTenantId(config);
126+
assertEquals("test-tenant-id", result);
127+
assertNull(config.getAzureTenantId()); // Config should remain unchanged
128+
}
129+
}
130+
131+
@Test
132+
public void testInferTenantIdSkipsWhenNotAzure() {
133+
DatabricksConfig config = new DatabricksConfig();
134+
config.setHost("https://my-workspace.cloud.databricks.com"); // non-azure host
135+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
136+
137+
DatabricksException exception =
138+
assertThrows(
139+
DatabricksException.class,
140+
() -> {
141+
AzureUtils.inferTenantId(config);
142+
});
143+
assertEquals("Cannot infer tenant ID: workspace is not Azure", exception.getMessage());
144+
assertNull(config.getAzureTenantId());
145+
}
146+
147+
@Test
148+
public void testInferTenantIdSkipsWhenAlreadySet() {
149+
DatabricksConfig config = new DatabricksConfig();
150+
config.setHost("https://adb-123.0.azuredatabricks.net");
151+
config.setAzureTenantId("existing-tenant-id");
152+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
153+
String result = AzureUtils.inferTenantId(config);
154+
assertEquals("existing-tenant-id", result);
155+
assertEquals("existing-tenant-id", config.getAzureTenantId()); // Config should remain unchanged
156+
}
157+
158+
@Test
159+
public void testInferTenantIdSkipsWhenNoHost() {
160+
DatabricksConfig config = new DatabricksConfig();
161+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
162+
163+
DatabricksException exception =
164+
assertThrows(
165+
DatabricksException.class,
166+
() -> {
167+
AzureUtils.inferTenantId(config);
168+
});
169+
assertEquals("Cannot infer tenant ID: host is missing", exception.getMessage());
170+
assertNull(config.getAzureTenantId());
171+
}
172+
}

0 commit comments

Comments
 (0)