1414import java .io .File ;
1515import java .io .IOException ;
1616import java .lang .reflect .Field ;
17+ import java .net .URL ;
1718import java .util .*;
1819import org .apache .http .HttpMessage ;
20+ import org .slf4j .Logger ;
21+ import org .slf4j .LoggerFactory ;
1922
2023public class DatabricksConfig {
24+
25+ private static final Logger logger = LoggerFactory .getLogger (DatabricksConfig .class );
2126 private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider ();
2227
2328 @ ConfigAttribute (env = "DATABRICKS_HOST" )
@@ -414,13 +419,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) {
414419 return this ;
415420 }
416421
417- /** @deprecated Use {@link #getAzureUseMsi()} instead. */
422+ /**
423+ * @deprecated Use {@link #getAzureUseMsi()} instead.
424+ */
418425 @ Deprecated ()
419426 public boolean getAzureUseMSI () {
420427 return azureUseMsi ;
421428 }
422429
423- /** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */
430+ /**
431+ * @deprecated Use {@link #setAzureUseMsi(boolean)} instead.
432+ */
424433 @ Deprecated
425434 public DatabricksConfig setAzureUseMSI (boolean azureUseMsi ) {
426435 this .azureUseMsi = azureUseMsi ;
@@ -726,7 +735,7 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
726735 }
727736
728737 public DatabricksConfig clone () {
729- return clone (new HashSet <>());
738+ return clone (new HashSet <>(Collections . singletonList ( "logger" ) ));
730739 }
731740
732741 public DatabricksConfig newWithWorkspaceHost (String host ) {
@@ -736,6 +745,7 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
736745 // The config for WorkspaceClient has a different host and Azure Workspace resource
737746 // ID, and also omits
738747 // the account ID.
748+ "logger" ,
739749 "host" ,
740750 "accountId" ,
741751 "azureWorkspaceResourceId" ,
@@ -755,4 +765,82 @@ public DatabricksConfig newWithWorkspaceHost(String host) {
755765 public String getEffectiveOAuthRedirectUrl () {
756766 return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback" ;
757767 }
768+
769+ private static final String AZURE_AUTH_ENDPOINT = "/aad/auth" ;
770+
771+ /**
772+ * [Internal] Load the Azure tenant ID from the Azure Databricks login page. If the tenant ID is
773+ * already set, this method does nothing.
774+ */
775+ public void loadAzureTenantId () {
776+
777+ if (!isAzure () || azureTenantId != null || host == null ) {
778+ return ;
779+ }
780+
781+ String loginUrl = host + AZURE_AUTH_ENDPOINT ;
782+ logger .debug ("Loading tenant ID from {}" , loginUrl );
783+
784+ try {
785+ String redirectLocation = getRedirectLocation (loginUrl );
786+ if (redirectLocation == null ) {
787+ return ;
788+ }
789+
790+ String extractedTenantId = extractTenantIdFromUrl (redirectLocation );
791+ if (extractedTenantId == null ) {
792+ return ;
793+ }
794+
795+ this .azureTenantId = extractedTenantId ;
796+ logger .debug ("Loaded tenant ID: {}" , this .azureTenantId );
797+
798+ } catch (Exception e ) {
799+ logger .warn ("Failed to load tenant ID: {}" , e .getMessage ());
800+ }
801+ }
802+
803+ private String getRedirectLocation (String loginUrl ) throws IOException {
804+
805+ Request request = new Request ("GET" , loginUrl );
806+ request .setRedirectionBehavior (false );
807+ Response response = getHttpClient ().execute (request );
808+ int statusCode = response .getStatusCode ();
809+
810+ if (statusCode / 100 != 3 ) {
811+ logger .warn (
812+ "Failed to get tenant ID from {}: expected status code 3xx, got {}" ,
813+ loginUrl ,
814+ statusCode );
815+ return null ;
816+ }
817+
818+ String location = response .getFirstHeader ("Location" );
819+ if (location == null ) {
820+ logger .warn ("No Location header in response from {}" , loginUrl );
821+ }
822+
823+ return location ;
824+ }
825+
826+ private String extractTenantIdFromUrl (String redirectUrl ) {
827+ try {
828+ // The Location header has the following form:
829+ // https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
830+ // The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US
831+ // Government cloud).
832+ URL entraIdUrl = new URL (redirectUrl );
833+ String [] pathSegments = entraIdUrl .getPath ().split ("/" );
834+
835+ if (pathSegments .length < 2 ) {
836+ logger .warn ("Invalid path in Location header: {}" , entraIdUrl .getPath ());
837+ return null ;
838+ }
839+
840+ return pathSegments [1 ];
841+ } catch (Exception e ) {
842+ logger .warn ("Failed to extract tenant ID from URL {}: {}" , redirectUrl , e .getMessage ());
843+ return null ;
844+ }
845+ }
758846}
0 commit comments