Skip to content

Commit 5a1fd5b

Browse files
committed
infer azure tenant id
Signed-off-by: Sreekanth Vadigi <[email protected]>
1 parent 3a1bbb1 commit 5a1fd5b

File tree

3 files changed

+180
-5
lines changed

3 files changed

+180
-5
lines changed

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414
import java.io.File;
1515
import java.io.IOException;
1616
import java.lang.reflect.Field;
17+
import java.net.URL;
1718
import java.util.*;
1819
import org.apache.http.HttpMessage;
20+
import org.slf4j.Logger;
21+
import org.slf4j.LoggerFactory;
1922

2023
public class DatabricksConfig {
24+
25+
private static final Logger logger = LoggerFactory.getLogger(DatabricksConfig.class);
2126
private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();
2227

2328
@ConfigAttribute(env = "DATABRICKS_HOST")
@@ -414,13 +419,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) {
414419
return this;
415420
}
416421

417-
/** @deprecated Use {@link #getAzureUseMsi()} instead. */
422+
/**
423+
* @deprecated Use {@link #getAzureUseMsi()} instead.
424+
*/
418425
@Deprecated()
419426
public boolean getAzureUseMSI() {
420427
return azureUseMsi;
421428
}
422429

423-
/** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */
430+
/**
431+
* @deprecated Use {@link #setAzureUseMsi(boolean)} instead.
432+
*/
424433
@Deprecated
425434
public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) {
426435
this.azureUseMsi = azureUseMsi;
@@ -726,7 +735,7 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
726735
}
727736

728737
public DatabricksConfig clone() {
729-
return clone(new HashSet<>());
738+
return clone(new HashSet<>(Collections.singletonList("logger")));
730739
}
731740

732741
public DatabricksConfig newWithWorkspaceHost(String host) {
@@ -736,6 +745,7 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
736745
// The config for WorkspaceClient has a different host and Azure Workspace resource
737746
// ID, and also omits
738747
// the account ID.
748+
"logger",
739749
"host",
740750
"accountId",
741751
"azureWorkspaceResourceId",
@@ -755,4 +765,82 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
755765
public String getEffectiveOAuthRedirectUrl() {
756766
return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback";
757767
}
768+
769+
private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";
770+
771+
/**
772+
* [Internal] Load the Azure tenant ID from the Azure Databricks login page. If the tenant ID is
773+
* already set, this method does nothing.
774+
*/
775+
public void loadAzureTenantId() {
776+
777+
if (!isAzure() || azureTenantId != null || host == null) {
778+
return;
779+
}
780+
781+
String loginUrl = host + AZURE_AUTH_ENDPOINT;
782+
logger.debug("Loading tenant ID from {}", loginUrl);
783+
784+
try {
785+
String redirectLocation = getRedirectLocation(loginUrl);
786+
if (redirectLocation == null) {
787+
return;
788+
}
789+
790+
String extractedTenantId = extractTenantIdFromUrl(redirectLocation);
791+
if (extractedTenantId == null) {
792+
return;
793+
}
794+
795+
this.azureTenantId = extractedTenantId;
796+
logger.debug("Loaded tenant ID: {}", this.azureTenantId);
797+
798+
} catch (Exception e) {
799+
logger.warn("Failed to load tenant ID: {}", e.getMessage());
800+
}
801+
}
802+
803+
private String getRedirectLocation(String loginUrl) throws IOException {
804+
805+
Request request = new Request("GET", loginUrl);
806+
request.setRedirectionBehavior(false);
807+
Response response = getHttpClient().execute(request);
808+
int statusCode = response.getStatusCode();
809+
810+
if (statusCode / 100 != 3) {
811+
logger.warn(
812+
"Failed to get tenant ID from {}: expected status code 3xx, got {}",
813+
loginUrl,
814+
statusCode);
815+
return null;
816+
}
817+
818+
String location = response.getFirstHeader("Location");
819+
if (location == null) {
820+
logger.warn("No Location header in response from {}", loginUrl);
821+
}
822+
823+
return location;
824+
}
825+
826+
private String extractTenantIdFromUrl(String redirectUrl) {
827+
try {
828+
// The Location header has the following form:
829+
// https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
830+
// The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US
831+
// Government cloud).
832+
URL entraIdUrl = new URL(redirectUrl);
833+
String[] pathSegments = entraIdUrl.getPath().split("/");
834+
835+
if (pathSegments.length < 2) {
836+
logger.warn("Invalid path in Location header: {}", entraIdUrl.getPath());
837+
return null;
838+
}
839+
840+
return pathSegments[1];
841+
} catch (Exception e) {
842+
logger.warn("Failed to extract tenant ID from URL {}: {}", redirectUrl, e.getMessage());
843+
return null;
844+
}
845+
}
758846
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ public String authType() {
2222
public OAuthHeaderFactory configure(DatabricksConfig config) {
2323
if (!config.isAzure()
2424
|| config.getAzureClientId() == null
25-
|| config.getAzureClientSecret() == null
26-
|| config.getAzureTenantId() == null) {
25+
|| config.getAzureClientSecret() == null) {
2726
return null;
2827
}
2928
AzureUtils.ensureHostPresent(
3029
config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);
30+
config.loadAzureTenantId();
3131
CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
3232
CachedTokenSource cloud =
3333
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());

databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,91 @@ public void testGetTokenSourceWithOAuth() {
250250
assertFalse(tokenSource instanceof ErrorTokenSource);
251251
assertEquals(tokenSource.getToken().getAccessToken(), "test-token");
252252
}
253+
254+
@Test
255+
public void testLoadAzureTenantId404() throws IOException {
256+
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
257+
DatabricksConfig config = new DatabricksConfig();
258+
config.setHost(server.getUrl());
259+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
260+
config.loadAzureTenantId();
261+
assertNull(config.getAzureTenantId());
262+
}
263+
}
264+
265+
@Test
266+
public void testLoadAzureTenantIdNoLocationHeader() throws IOException {
267+
try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
268+
DatabricksConfig config = new DatabricksConfig();
269+
config.setHost(server.getUrl());
270+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
271+
config.loadAzureTenantId();
272+
assertNull(config.getAzureTenantId());
273+
}
274+
}
275+
276+
@Test
277+
public void testLoadAzureTenantIdUnparsableLocationHeader() throws IOException {
278+
FixtureServer.FixtureMapping fixture =
279+
new FixtureServer.FixtureMapping.Builder()
280+
.validateMethod("GET")
281+
.validatePath("/aad/auth")
282+
.withRedirect("https://unexpected-location", 302)
283+
.build();
284+
285+
try (FixtureServer server = new FixtureServer().with(fixture)) {
286+
DatabricksConfig config = new DatabricksConfig();
287+
config.setHost(server.getUrl());
288+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
289+
config.loadAzureTenantId();
290+
assertNull(config.getAzureTenantId());
291+
}
292+
}
293+
294+
@Test
295+
public void testLoadAzureTenantIdHappyPath() throws IOException {
296+
FixtureServer.FixtureMapping fixture =
297+
new FixtureServer.FixtureMapping.Builder()
298+
.validateMethod("GET")
299+
.validatePath("/aad/auth")
300+
.withRedirect("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize", 302)
301+
.build();
302+
303+
try (FixtureServer server = new FixtureServer().with(fixture)) {
304+
DatabricksConfig config = new DatabricksConfig();
305+
config.setHost(server.getUrl());
306+
config.setAzureWorkspaceResourceId(
307+
"/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
308+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
309+
config.loadAzureTenantId();
310+
assertEquals("test-tenant-id", config.getAzureTenantId());
311+
}
312+
}
313+
314+
@Test
315+
public void testLoadAzureTenantIdSkipsWhenNotAzure() throws IOException {
316+
DatabricksConfig config = new DatabricksConfig();
317+
config.setHost("https://my-workspace.cloud.databricks.com"); // non-azure host
318+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
319+
config.loadAzureTenantId();
320+
assertNull(config.getAzureTenantId());
321+
}
322+
323+
@Test
324+
public void testLoadAzureTenantIdSkipsWhenAlreadySet() throws IOException {
325+
DatabricksConfig config = new DatabricksConfig();
326+
config.setHost("https://adb-123.0.azuredatabricks.net");
327+
config.setAzureTenantId("existing-tenant-id");
328+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
329+
config.loadAzureTenantId();
330+
assertEquals("existing-tenant-id", config.getAzureTenantId());
331+
}
332+
333+
@Test
334+
public void testLoadAzureTenantIdSkipsWhenNoHost() throws IOException {
335+
DatabricksConfig config = new DatabricksConfig();
336+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
337+
config.loadAzureTenantId();
338+
assertNull(config.getAzureTenantId());
339+
}
253340
}

0 commit comments

Comments
 (0)