Skip to content

Commit 7bb4fd0

Browse files
authored
Set necessary headers when authenticating via Azure CLI (#136)
## Changes The Java SDK request authentication logic is inconsistent between the Azure login types: for service principal auth, the SDK correctly adds the X-Databricks-Azure-Workspace-Resource-Id when configured, but this is missed for Azure CLI auth. Additionally, when logging in via Azure CLI using a service principal, the service management token must also be fetched from the CLI. This PR fixes this by defining the logic to attach these header in a common function that is used by all Azure-specific authentication types. See databricks/databricks-sdk-go#584 for the same change in the Go SDK. See databricks/databricks-sdk-py#290 for the same changes in the Python SDK. ## Tests - [x] Unit tests to cover the two scenarios for Azure CLI w.r.t. management endpoint token fetching, and one to verify that X-Databricks-Azure-Workspace-Resource-Id is included when using Azure CLI.
1 parent e0174d0 commit 7bb4fd0

File tree

11 files changed

+176
-90
lines changed

11 files changed

+176
-90
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
fmt:
2+
mvn spotless:apply
3+

databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,24 @@ public HeaderFactory configure(DatabricksConfig config) {
3737
ensureHostPresent(config, mapper);
3838
String resource = config.getEffectiveAzureLoginAppId();
3939
CliTokenSource tokenSource = tokenSourceFor(config, resource);
40+
CliTokenSource mgmtTokenSource =
41+
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
4042
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
43+
try {
44+
mgmtTokenSource.getToken();
45+
} catch (Exception e) {
46+
LOG.debug("Not including service management token in headers", e);
47+
mgmtTokenSource = null;
48+
}
49+
CliTokenSource finalMgmtTokenSource = mgmtTokenSource;
4150
return () -> {
4251
Token token = tokenSource.getToken();
4352
Map<String, String> headers = new HashMap<>();
4453
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
45-
return headers;
54+
if (finalMgmtTokenSource != null) {
55+
addSpManagementToken(finalMgmtTokenSource, headers);
56+
}
57+
return addWorkspaceResourceId(config, headers);
4658
};
4759
} catch (DatabricksException e) {
4860
String stderr = e.getMessage();

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,8 @@ public HeaderFactory configure(DatabricksConfig config) {
3434
return () -> {
3535
Map<String, String> headers = new HashMap<>();
3636
headers.put("Authorization", "Bearer " + inner.getToken().getAccessToken());
37-
headers.put("X-Databricks-Azure-SP-Management-Token", cloud.getToken().getAccessToken());
38-
if (config.getAzureWorkspaceResourceId() != null) {
39-
headers.put(
40-
"X-Databricks-Azure-Workspace-Resource-Id", config.getAzureWorkspaceResourceId());
41-
}
37+
addWorkspaceResourceId(config, headers);
38+
addSpManagementToken(cloud, headers);
4239
return headers;
4340
};
4441
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,18 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) {
9292
throw new DatabricksException("Unable to fetch workspace URL: " + e.getMessage(), e);
9393
}
9494
}
95+
96+
default Map<String, String> addWorkspaceResourceId(
97+
DatabricksConfig config, Map<String, String> headers) {
98+
if (config.getAzureWorkspaceResourceId() != null) {
99+
headers.put("X-Databricks-Azure-Workspace-Resource-Id", config.getAzureWorkspaceResourceId());
100+
}
101+
return headers;
102+
}
103+
104+
default Map<String, String> addSpManagementToken(
105+
RefreshableTokenSource tokenSource, Map<String, String> headers) {
106+
headers.put("X-Databricks-Azure-SP-Management-Token", tokenSource.getToken().getAccessToken());
107+
return headers;
108+
}
95109
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package com.databricks.sdk;
2+
3+
import com.databricks.sdk.core.ConfigResolving;
4+
import com.databricks.sdk.core.DatabricksConfig;
5+
import com.databricks.sdk.core.utils.TestOSUtils;
6+
import java.util.Map;
7+
import org.junit.jupiter.api.Assertions;
8+
import org.junit.jupiter.api.Test;
9+
10+
public class DatabricksAuthManualTest implements ConfigResolving {
11+
@Test
12+
void azureCliWorkspaceHeaderPresent() {
13+
StaticEnv env =
14+
new StaticEnv()
15+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
16+
.with("PATH", "testdata:/bin");
17+
String azureWorkspaceResourceId =
18+
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
19+
DatabricksConfig config =
20+
new DatabricksConfig()
21+
.setAuthType("azure-cli")
22+
.setHost("https://x")
23+
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
24+
resolveConfig(config, env);
25+
Map<String, String> headers = config.authenticate();
26+
Assertions.assertEquals(
27+
azureWorkspaceResourceId, headers.get("X-Databricks-Azure-Workspace-Resource-Id"));
28+
}
29+
30+
@Test
31+
void azureCliUserWithManagementAccess() {
32+
StaticEnv env =
33+
new StaticEnv()
34+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
35+
.with("PATH", "testdata:/bin");
36+
String azureWorkspaceResourceId =
37+
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
38+
DatabricksConfig config =
39+
new DatabricksConfig()
40+
.setAuthType("azure-cli")
41+
.setHost("https://x")
42+
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
43+
resolveConfig(config, env);
44+
Map<String, String> headers = config.authenticate();
45+
Assertions.assertEquals("...", headers.get("X-Databricks-Azure-SP-Management-Token"));
46+
}
47+
48+
@Test
49+
void azureCliUserNoManagementAccess() {
50+
StaticEnv env =
51+
new StaticEnv()
52+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
53+
.with("PATH", "testdata:/bin")
54+
.with("FAIL_IF", "https://management.core.windows.net/");
55+
String azureWorkspaceResourceId =
56+
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
57+
DatabricksConfig config =
58+
new DatabricksConfig()
59+
.setAuthType("azure-cli")
60+
.setHost("https://x")
61+
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
62+
resolveConfig(config, env);
63+
Map<String, String> headers = config.authenticate();
64+
Assertions.assertNull(headers.get("X-Databricks-Azure-SP-Management-Token"));
65+
}
66+
}

databricks-sdk-java/src/test/java/com/databricks/sdk/DatabricksAuthTest.java

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,12 @@
1010
import com.databricks.sdk.core.utils.GitHubUtils;
1111
import com.databricks.sdk.core.utils.TestOSUtils;
1212
import java.io.File;
13-
import java.net.URL;
14-
import java.util.HashMap;
15-
import java.util.Map;
16-
import java.util.function.Supplier;
1713
import org.junit.jupiter.api.Test;
1814

19-
public class DatabricksAuthTest implements TestOSUtils, GitHubUtils, ConfigResolving {
20-
21-
private static String prefixPath;
15+
public class DatabricksAuthTest implements GitHubUtils, ConfigResolving {
2216

2317
public DatabricksAuthTest() {
2418
setPermissionOnTestAz();
25-
prefixPath = System.getProperty("user.dir") + getTestDir();
2619
}
2720

2821
@Test
@@ -209,7 +202,7 @@ public void testTestConfigConfigFile() {
209202
@Test
210203
public void testTestConfigConfigFileSkipDefaultProfileIfHostSpecified() {
211204
// Set environment variables
212-
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata"));
205+
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata"));
213206
raises(
214207
"default auth: cannot configure default credentials. Config: host=https://x",
215208
() -> {
@@ -222,7 +215,7 @@ public void testTestConfigConfigFileSkipDefaultProfileIfHostSpecified() {
222215
@Test
223216
public void testTestConfigConfigFileWithEmptyDefaultProfileSelectDefault() {
224217
// Set environment variables
225-
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata/empty_default"));
218+
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata/empty_default"));
226219
raises(
227220
"default auth: cannot configure default credentials",
228221
() -> {
@@ -238,7 +231,7 @@ public void testTestConfigConfigFileWithEmptyDefaultProfileSelectAbc() {
238231
StaticEnv env =
239232
new StaticEnv()
240233
.with("DATABRICKS_CONFIG_PROFILE", "abc")
241-
.with("HOME", resource("/testdata/empty_default"));
234+
.with("HOME", TestOSUtils.resource("/testdata/empty_default"));
242235
DatabricksConfig config = new DatabricksConfig();
243236
resolveConfig(config, env);
244237
config.authenticate();
@@ -250,7 +243,7 @@ public void testTestConfigConfigFileWithEmptyDefaultProfileSelectAbc() {
250243
@Test
251244
public void testTestConfigPatFromDatabricksCfg() {
252245
// Set environment variables
253-
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata"));
246+
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata"));
254247
DatabricksConfig config = new DatabricksConfig();
255248
resolveConfig(config, env);
256249
config.authenticate();
@@ -265,7 +258,7 @@ public void testTestConfigPatFromDatabricksCfgDotProfile() {
265258
StaticEnv env =
266259
new StaticEnv()
267260
.with("DATABRICKS_CONFIG_PROFILE", "pat.with.dot")
268-
.with("HOME", resource("/testdata"));
261+
.with("HOME", TestOSUtils.resource("/testdata"));
269262
DatabricksConfig config = new DatabricksConfig();
270263
resolveConfig(config, env);
271264
config.authenticate();
@@ -280,7 +273,7 @@ public void testTestConfigPatFromDatabricksCfgNohostProfile() {
280273
StaticEnv env =
281274
new StaticEnv()
282275
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
283-
.with("HOME", resource("/testdata"));
276+
.with("HOME", TestOSUtils.resource("/testdata"));
284277
raises(
285278
"default auth: cannot configure default credentials. Config: token=***, profile=nohost. Env: DATABRICKS_CONFIG_PROFILE",
286279
() -> {
@@ -297,7 +290,7 @@ public void testTestConfigConfigProfileAndToken() {
297290
new StaticEnv()
298291
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
299292
.with("DATABRICKS_TOKEN", "x")
300-
.with("HOME", resource("/testdata"));
293+
.with("HOME", TestOSUtils.resource("/testdata"));
301294
raises(
302295
"default auth: cannot configure default credentials. Config: token=***, profile=nohost. Env: DATABRICKS_TOKEN, DATABRICKS_CONFIG_PROFILE",
303296
() -> {
@@ -314,7 +307,7 @@ public void testTestConfigConfigProfileAndPassword() {
314307
new StaticEnv()
315308
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
316309
.with("DATABRICKS_USERNAME", "x")
317-
.with("HOME", resource("/testdata"));
310+
.with("HOME", TestOSUtils.resource("/testdata"));
318311
raises(
319312
"validate: more than one authorization method configured: basic and pat. Config: token=***, username=x, profile=nohost. Env: DATABRICKS_USERNAME, DATABRICKS_CONFIG_PROFILE",
320313
() -> {
@@ -341,7 +334,9 @@ public void testTestConfigAzurePat() {
341334
public void testTestConfigAzureCliHost() {
342335
// Set environment variables
343336
StaticEnv env =
344-
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "testdata:/bin");
337+
new StaticEnv()
338+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
339+
.with("PATH", "testdata:/bin");
345340
DatabricksConfig config =
346341
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
347342
resolveConfig(config, env);
@@ -358,7 +353,7 @@ public void testTestConfigAzureCliHostFail() {
358353
StaticEnv env =
359354
new StaticEnv()
360355
.with("FAIL", "yes")
361-
.with("HOME", resource("/testdata/azure"))
356+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
362357
.with("PATH", "testdata:/bin");
363358
raises(
364359
"default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws",
@@ -374,7 +369,9 @@ public void testTestConfigAzureCliHostFail() {
374369
public void testTestConfigAzureCliHostAzNotInstalled() {
375370
// Set environment variables
376371
StaticEnv env =
377-
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "whatever");
372+
new StaticEnv()
373+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
374+
.with("PATH", "whatever");
378375
raises(
379376
"default auth: cannot configure default credentials. Config: azure_workspace_resource_id=/sub/rg/ws",
380377
() -> {
@@ -389,7 +386,9 @@ public void testTestConfigAzureCliHostAzNotInstalled() {
389386
public void testTestConfigAzureCliHostPatConflictWithConfigFilePresentWithoutDefaultProfile() {
390387
// Set environment variables
391388
StaticEnv env =
392-
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "testdata:/bin");
389+
new StaticEnv()
390+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
391+
.with("PATH", "testdata:/bin");
393392
raises(
394393
"validate: more than one authorization method configured: azure and pat. Config: token=***, azure_workspace_resource_id=/sub/rg/ws",
395394
() -> {
@@ -404,7 +403,9 @@ public void testTestConfigAzureCliHostPatConflictWithConfigFilePresentWithoutDef
404403
public void testTestConfigAzureCliHostAndResourceId() {
405404
// Set environment variables
406405
StaticEnv env =
407-
new StaticEnv().with("HOME", resource("/testdata")).with("PATH", "testdata:/bin");
406+
new StaticEnv()
407+
.with("HOME", TestOSUtils.resource("/testdata"))
408+
.with("PATH", "testdata:/bin");
408409
DatabricksConfig config =
409410
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
410411
resolveConfig(config, env);
@@ -421,7 +422,7 @@ public void testTestConfigAzureCliHostAndResourceIDConfigurationPrecedence() {
421422
StaticEnv env =
422423
new StaticEnv()
423424
.with("DATABRICKS_CONFIG_PROFILE", "justhost")
424-
.with("HOME", resource("/testdata/azure"))
425+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
425426
.with("PATH", "testdata:/bin");
426427
DatabricksConfig config =
427428
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
@@ -439,7 +440,7 @@ public void testTestConfigAzureAndPasswordConflict() {
439440
StaticEnv env =
440441
new StaticEnv()
441442
.with("DATABRICKS_USERNAME", "x")
442-
.with("HOME", resource("/testdata/azure"))
443+
.with("HOME", TestOSUtils.resource("/testdata/azure"))
443444
.with("PATH", "testdata:/bin");
444445
raises(
445446
"validate: more than one authorization method configured: azure and basic. Config: host=x, username=x, azure_workspace_resource_id=/sub/rg/ws. Env: DATABRICKS_USERNAME",
@@ -457,7 +458,7 @@ public void testTestConfigCorruptConfig() {
457458
StaticEnv env =
458459
new StaticEnv()
459460
.with("DATABRICKS_CONFIG_PROFILE", "DEFAULT")
460-
.with("HOME", resource("/testdata/corrupt"));
461+
.with("HOME", TestOSUtils.resource("/testdata/corrupt"));
461462
raises(
462463
"resolve: testdata/corrupt/.databrickscfg has no DEFAULT profile configured. Config: profile=DEFAULT. Env: DATABRICKS_CONFIG_PROFILE",
463464
() -> {
@@ -484,31 +485,6 @@ public void testTestConfigAuthTypeFromEnv() {
484485
assertEquals("https://x", config.getHost());
485486
}
486487

487-
private String resource(String file) {
488-
URL resource = getClass().getResource(file);
489-
if (resource == null) {
490-
fail("Asset not found: " + file);
491-
}
492-
return resource.getFile();
493-
}
494-
495-
static class StaticEnv implements Supplier<Map<String, String>> {
496-
private final Map<String, String> env = new HashMap<>();
497-
498-
public StaticEnv with(String key, String value) {
499-
if (key.equals("PATH")) {
500-
value = prefixPath + value;
501-
}
502-
env.put(key, value);
503-
return this;
504-
}
505-
506-
@Override
507-
public Map<String, String> get() {
508-
return env;
509-
}
510-
}
511-
512488
private void raises(String contains, Runnable cb) {
513489
boolean raised = false;
514490
try {
@@ -521,7 +497,7 @@ private void raises(String contains, Runnable cb) {
521497
File.separator,
522498
"/"); // We would need to do this upstream also for making paths compatible with
523499
// windows
524-
message = message.replace(prefixPath, "");
500+
message = message.replace(StaticEnv.getPrefixPath(), "");
525501
if (!message.contains(contains)) {
526502
fail(String.format("Expected exception to contain '%s'", contains), e);
527503
}

0 commit comments

Comments
 (0)