Skip to content

Commit cf281fd

Browse files
committed
moved inferring logic to azure utils
Signed-off-by: Sreekanth Vadigi <[email protected]>
1 parent f06c006 commit cf281fd

File tree

5 files changed

+201
-196
lines changed

5 files changed

+201
-196
lines changed

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 1 addition & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,11 @@
1414
import java.io.File;
1515
import java.io.IOException;
1616
import java.lang.reflect.Field;
17-
import java.net.URL;
1817
import java.util.*;
1918
import org.apache.http.HttpMessage;
20-
import org.slf4j.Logger;
21-
import org.slf4j.LoggerFactory;
2219

2320
public class DatabricksConfig {
2421

25-
private static final Logger logger = LoggerFactory.getLogger(DatabricksConfig.class);
26-
27-
/** Azure authentication endpoint for tenant ID discovery */
28-
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";
29-
3022
private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();
3123

3224
@ConfigAttribute(env = "DATABRICKS_HOST")
@@ -735,7 +727,7 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
735727
}
736728

737729
public DatabricksConfig clone() {
738-
return clone(new HashSet<>(Arrays.asList("logger", "AZURE_AUTH_ENDPOINT")));
730+
return clone(new HashSet<>());
739731
}
740732

741733
public DatabricksConfig newWithWorkspaceHost(String host) {
@@ -745,8 +737,6 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
745737
// The config for WorkspaceClient has a different host and Azure Workspace resource
746738
// ID, and also omits
747739
// the account ID.
748-
"logger",
749-
"AZURE_AUTH_ENDPOINT",
750740
"host",
751741
"accountId",
752742
"azureWorkspaceResourceId",
@@ -766,88 +756,4 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
766756
public String getEffectiveOAuthRedirectUrl() {
767757
return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback";
768758
}
769-
770-
/**
771-
* [Internal] Load the Azure tenant ID from the Azure Databricks login page. If the tenant ID is
772-
* already set, this method does nothing.
773-
*
774-
* @return true if tenant ID is available (either was already set or successfully loaded), false otherwise
775-
*/
776-
public boolean loadAzureTenantId() {
777-
778-
if (azureTenantId != null) {
779-
return true; // Tenant ID already available - success
780-
}
781-
782-
if (!isAzure() || host == null) {
783-
return false; // Configuration issue - can't perform operation
784-
}
785-
786-
String loginUrl = host + AZURE_AUTH_ENDPOINT;
787-
logger.debug("Loading tenant ID from {}", loginUrl);
788-
789-
try {
790-
String redirectLocation = getRedirectLocation(loginUrl);
791-
if (redirectLocation == null) {
792-
return false; // Failed to get redirect location
793-
}
794-
795-
String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
796-
if (extractedTenantId == null) {
797-
return false; // Failed to extract tenant ID
798-
}
799-
800-
this.azureTenantId = extractedTenantId;
801-
logger.debug("Loaded tenant ID: {}", this.azureTenantId);
802-
return true; // Successfully loaded
803-
804-
} catch (Exception e) {
805-
logger.warn("Failed to load tenant ID: {}", e.getMessage());
806-
return false;
807-
}
808-
}
809-
810-
private String getRedirectLocation(String loginUrl) throws IOException {
811-
812-
Request request = new Request("GET", loginUrl);
813-
request.setRedirectionBehavior(false);
814-
Response response = getHttpClient().execute(request);
815-
int statusCode = response.getStatusCode();
816-
817-
if (statusCode != 302) {
818-
logger.warn(
819-
"Failed to get tenant ID from {}: expected status code 302, got {}",
820-
loginUrl,
821-
statusCode);
822-
return null;
823-
}
824-
825-
String location = response.getFirstHeader("Location");
826-
if (location == null) {
827-
logger.warn("No Location header in response from {}", loginUrl);
828-
}
829-
830-
return location;
831-
}
832-
833-
private String extractTenantIdFromUrl(String redirectUrl) {
834-
try {
835-
// The Location header has the following form:
836-
// https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
837-
// The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US
838-
// Government cloud).
839-
URL entraIdUrl = new URL(redirectUrl);
840-
String[] pathSegments = entraIdUrl.getPath().split("/");
841-
842-
if (pathSegments.length < 2) {
843-
logger.warn("Invalid path in Location header: {}", entraIdUrl.getPath());
844-
return null;
845-
}
846-
847-
return pathSegments[1];
848-
} catch (Exception e) {
849-
logger.warn("Failed to extract tenant ID from URL {}: {}", redirectUrl, e.getMessage());
850-
return null;
851-
}
852-
}
853759
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
*/
1313
public class AzureServicePrincipalCredentialsProvider implements CredentialsProvider {
1414
private final ObjectMapper mapper = new ObjectMapper();
15+
private String tenantId;
1516

1617
@Override
1718
public String authType() {
@@ -25,14 +26,15 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
2526
|| config.getAzureClientSecret() == null) {
2627
return null;
2728
}
28-
AzureUtils.ensureHostPresent(
29-
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);
30-
31-
boolean tenantIdLoaded = config.loadAzureTenantId();
32-
if (!tenantIdLoaded) {
29+
30+
this.tenantId = config.getAzureTenantId() != null ? config.getAzureTenantId() : AzureUtils.inferTenantId(config);
31+
if (this.tenantId == null) {
3332
return null;
3433
}
35-
34+
35+
AzureUtils.ensureHostPresent(
36+
config, mapper, this::tokenSourceFor);
37+
3638
CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
3739
CachedTokenSource cloud =
3840
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
@@ -60,9 +62,9 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
6062
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
6163
* resource.
6264
*/
63-
private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
65+
private CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
6466
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
65-
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
67+
String tokenUrl = aadEndpoint + this.tenantId + "/oauth2/token";
6668
Map<String, String> endpointParams = new HashMap<>();
6769
endpointParams.put("resource", resource);
6870

databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
import com.fasterxml.jackson.databind.ObjectMapper;
1111
import com.fasterxml.jackson.databind.node.ObjectNode;
1212
import java.io.IOException;
13+
import java.net.URL;
1314
import java.util.Map;
1415
import java.util.Optional;
1516
import java.util.function.BiFunction;
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
1619

1720
public class AzureUtils {
1821

22+
private static final Logger logger = LoggerFactory.getLogger(AzureUtils.class);
23+
24+
/** Azure authentication endpoint for tenant ID discovery */
25+
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";
26+
1927
public static String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
2028
JsonNode properties = jsonResponse.get("properties");
2129
if (properties == null) {
@@ -95,4 +103,80 @@ public static Optional<String> getAzureWorkspaceResourceId(Workspace workspace)
95103
workspace.getWorkspaceName());
96104
return Optional.of(resourceId);
97105
}
106+
107+
/**
108+
* Infers the Azure tenant ID from the Databricks workspace login page.
109+
*
110+
* @param config The DatabricksConfig instance
111+
* @return the discovered tenant ID, or null if discovery fails
112+
*/
113+
public static String inferTenantId(DatabricksConfig config) {
114+
if (config.getAzureTenantId() != null) {
115+
return config.getAzureTenantId();
116+
}
117+
118+
if (!config.isAzure() || config.getHost() == null) {
119+
logger.warn("Cannot infer tenant ID: workspace is not Azure or host is missing");
120+
return null;
121+
}
122+
123+
String loginUrl = config.getHost() + AZURE_AUTH_ENDPOINT;
124+
125+
try {
126+
String redirectLocation = getRedirectLocation(config, loginUrl);
127+
if (redirectLocation == null) {
128+
logger.warn("Failed to get redirect location from Azure auth endpoint: {}", loginUrl);
129+
return null;
130+
}
131+
132+
String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
133+
if (extractedTenantId == null) {
134+
logger.warn("Failed to extract tenant ID from redirect URL: {}", redirectLocation);
135+
return null;
136+
}
137+
138+
logger.info("Successfully discovered Azure tenant ID: {}", extractedTenantId);
139+
return extractedTenantId;
140+
141+
} catch (Exception e) {
142+
logger.warn("Exception occurred while inferring Azure tenant ID from {}: {}", loginUrl, e.getMessage());
143+
return null;
144+
}
145+
}
146+
147+
private static String getRedirectLocation(DatabricksConfig config, String loginUrl) throws IOException {
148+
Request request = new Request("GET", loginUrl);
149+
request.setRedirectionBehavior(false);
150+
Response response = config.getHttpClient().execute(request);
151+
152+
if (response.getStatusCode() != 302) {
153+
logger.warn("Expected redirect (302) from {}, got status code: {}", loginUrl, response.getStatusCode());
154+
return null;
155+
}
156+
157+
String location = response.getFirstHeader("Location");
158+
if (location == null) {
159+
logger.warn("No Location header in redirect response from {}", loginUrl);
160+
}
161+
162+
return location;
163+
}
164+
165+
private static String extractTenantIdFromUrl(String redirectUrl) {
166+
try {
167+
// Parse: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
168+
URL entraIdUrl = new URL(redirectUrl);
169+
String[] pathSegments = entraIdUrl.getPath().split("/");
170+
171+
if (pathSegments.length < 2) {
172+
logger.warn("Invalid path in Location header: {}", entraIdUrl.getPath());
173+
return null;
174+
}
175+
176+
return pathSegments[1];
177+
} catch (Exception e) {
178+
logger.warn("Failed to parse tenant ID from URL {}: {}", redirectUrl, e.getMessage());
179+
return null;
180+
}
181+
}
98182
}

databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -251,97 +251,4 @@ public void testGetTokenSourceWithOAuth() {
251251
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
252252
}
253253

254-
@Test
255-
public void testLoadAzureTenantId404() throws IOException {
256-
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
257-
DatabricksConfig config = new DatabricksConfig();
258-
config.setHost(server.getUrl());
259-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
260-
boolean result = config.loadAzureTenantId();
261-
assertFalse(result);
262-
assertNull(config.getAzureTenantId());
263-
}
264-
}
265-
266-
@Test
267-
public void testLoadAzureTenantIdNoLocationHeader() throws IOException {
268-
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
269-
DatabricksConfig config = new DatabricksConfig();
270-
config.setHost(server.getUrl());
271-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
272-
boolean result = config.loadAzureTenantId();
273-
assertFalse(result);
274-
assertNull(config.getAzureTenantId());
275-
}
276-
}
277-
278-
@Test
279-
public void testLoadAzureTenantIdUnparsableLocationHeader() throws IOException {
280-
FixtureServer.FixtureMapping fixture =
281-
new FixtureServer.FixtureMapping.Builder()
282-
.validateMethod("GET")
283-
.validatePath("/aad/auth")
284-
.withRedirect("https://unexpected-location", 302)
285-
.build();
286-
287-
try (FixtureServer server = new FixtureServer().with(fixture)) {
288-
DatabricksConfig config = new DatabricksConfig();
289-
config.setHost(server.getUrl());
290-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
291-
boolean result = config.loadAzureTenantId();
292-
assertFalse(result);
293-
assertNull(config.getAzureTenantId());
294-
}
295-
}
296-
297-
@Test
298-
public void testLoadAzureTenantIdHappyPath() throws IOException {
299-
FixtureServer.FixtureMapping fixture =
300-
new FixtureServer.FixtureMapping.Builder()
301-
.validateMethod("GET")
302-
.validatePath("/aad/auth")
303-
.withRedirect("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize", 302)
304-
.build();
305-
306-
try (FixtureServer server = new FixtureServer().with(fixture)) {
307-
DatabricksConfig config = new DatabricksConfig();
308-
config.setHost(server.getUrl());
309-
config.setAzureWorkspaceResourceId(
310-
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
311-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
312-
boolean result = config.loadAzureTenantId();
313-
assertTrue(result);
314-
assertEquals("test-tenant-id", config.getAzureTenantId());
315-
}
316-
}
317-
318-
@Test
319-
public void testLoadAzureTenantIdSkipsWhenNotAzure() throws IOException {
320-
DatabricksConfig config = new DatabricksConfig();
321-
config.setHost("https://my-workspace.cloud.databricks.com"); // non-azure host
322-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
323-
boolean result = config.loadAzureTenantId();
324-
assertFalse(result);
325-
assertNull(config.getAzureTenantId());
326-
}
327-
328-
@Test
329-
public void testLoadAzureTenantIdSkipsWhenAlreadySet() throws IOException {
330-
DatabricksConfig config = new DatabricksConfig();
331-
config.setHost("https://adb-123.0.azuredatabricks.net");
332-
config.setAzureTenantId("existing-tenant-id");
333-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
334-
boolean result = config.loadAzureTenantId();
335-
assertTrue(result);
336-
assertEquals("existing-tenant-id", config.getAzureTenantId());
337-
}
338-
339-
@Test
340-
public void testLoadAzureTenantIdSkipsWhenNoHost() throws IOException {
341-
DatabricksConfig config = new DatabricksConfig();
342-
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
343-
boolean result = config.loadAzureTenantId();
344-
assertFalse(result);
345-
assertNull(config.getAzureTenantId());
346-
}
347254
}

0 commit comments

Comments
 (0)