Commit f5bde48

Support Databricks Workload Identity Federation for GitHub tokens (#423)
## What changes are proposed in this pull request?

This PR adds support for Databricks Workload Identity Federation (WIF) using GitHub tokens. This allows users to use WIF from their GitHub workflows and authenticate their workloads without long-lived secrets.

The new credentials strategy is added to the DefaultCredentialsStrategy (`DefaultCredentialsProvider` in this diff) after the other Databricks credentials strategies and before the cloud-specific authentication methods. WIF credentials use a subset of the configuration values of the other Databricks authentication methods, so placing WIF after them ensures that it is not used when another Databricks authentication method is configured. WIF uses the Databricks client ID, which is not used by cloud-specific authentication methods. Therefore, it will not be used when cloud-specific authentication methods are configured.

## How is this tested?

Added tests.
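The ordering argument above can be illustrated with a small, self-contained sketch. It is not the SDK's implementation: the `Provider` class, the configuration keys, and every auth type name except `github-oidc` are hypothetical stand-ins, and the sketch assumes the chain simply picks the first provider whose required settings are all present.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Illustrative only: a simplified stand-in for an ordered credentials chain. */
final class CredentialChainSketch {

  /** Hypothetical provider: an auth type plus the config keys it needs. */
  static final class Provider {
    final String authType;
    final List<String> requiredKeys;

    Provider(String authType, String... requiredKeys) {
      this.authType = authType;
      this.requiredKeys = Arrays.asList(requiredKeys);
    }

    boolean isConfigured(Map<String, String> config) {
      return config.keySet().containsAll(requiredKeys);
    }
  }

  // Assumed order: Databricks-native methods, then GitHub OIDC (WIF), then cloud-specific methods.
  static final List<Provider> CHAIN =
      Arrays.asList(
          new Provider("pat", "host", "token"),
          new Provider("oauth-m2m", "host", "client_id", "client_secret"),
          new Provider(
              "github-oidc",
              "host",
              "client_id",
              "actions_id_token_request_url",
              "actions_id_token_request_token"),
          new Provider(
              "azure-client-secret",
              "host",
              "azure_client_id",
              "azure_client_secret",
              "azure_tenant_id"));

  /** Returns the auth type of the first configured provider, or null if none matches. */
  static String resolve(Map<String, String> config) {
    for (Provider provider : CHAIN) {
      if (provider.isConfigured(config)) {
        return provider.authType;
      }
    }
    return null;
  }

  public static void main(String[] args) {
    Map<String, String> config = new HashMap<>();
    config.put("host", "https://example.invalid");
    config.put("client_id", "app-id");
    config.put("actions_id_token_request_url", "https://example.invalid/token?x=1");
    config.put("actions_id_token_request_token", "request-token");
    System.out.println(resolve(config));
  }
}
```

Under these assumptions, a configuration that carries both the client ID and cloud-specific settings inside GitHub Actions resolves to `github-oidc` before the cloud-specific provider is reached, which is the situation the changelog entry flags as breaking.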
1 parent d854b9c commit f5bde48

9 files changed, +302 -21 lines changed

NEXT_CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,10 @@
 ## Release v0.48.0

 ### New Features and Improvements
+* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([423](https://github.com/databricks/databricks-sdk-java/pull/423)).
+  See README.md for instructions.
+* [Breaking] Users running their workflows in GitHub Actions, which use Cloud native authentication and also have a `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST`
+  environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods.

 ### Bug Fixes

README.md

Lines changed: 4 additions & 4 deletions
@@ -116,18 +116,18 @@ Depending on the Databricks authentication method, the SDK uses the following in

 ### Databricks native authentication

-By default, the Databricks SDK for Java initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument).
+By default, the Databricks SDK for Java initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication using OIDC (`auth_type="github-oidc"` argument).

 - For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents.
-- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents.
+- For Databricks OIDC authentication, you must provide the `host`, `client_id` and `token_audience` _(optional)_ either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file.

 | Argument | Description | Environment variable |
 |--------------|-------------|-------------------|
 | `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
 | `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
 | `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
-| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` |
-| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` |
+| `client_id` | _(String)_ The Databricks Service Principal Application ID. | `DATABRICKS_CLIENT_ID` |
+| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. | `TOKEN_AUDIENCE` |

 For example, to use Databricks token authentication:

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 16 additions & 0 deletions
@@ -141,6 +141,13 @@ public class DatabricksConfig {

   private DatabricksEnvironment databricksEnvironment;

+  /**
+   * When using Workload Identity Federation, the audience to specify when fetching an ID token from
+   * the ID token supplier.
+   */
+  @ConfigAttribute(env = "TOKEN_AUDIENCE")
+  private String tokenAudience;
+
   public Environment getEnv() {
     return env;
   }
@@ -512,6 +519,15 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) {
     return this;
   }

+  public String getTokenAudience() {
+    return tokenAudience;
+  }
+
+  public DatabricksConfig setTokenAudience(String tokenAudience) {
+    this.tokenAudience = tokenAudience;
+    return this;
+  }
+
   public boolean isAzure() {
     if (azureWorkspaceResourceId != null) {
       return true;

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 2 additions & 4 deletions
@@ -1,9 +1,6 @@
 package com.databricks.sdk.core;

-import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider;
-import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider;
-import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
-import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider;
+import com.databricks.sdk.core.oauth.*;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -18,6 +15,7 @@ public class DefaultCredentialsProvider implements CredentialsProvider {
           PatCredentialsProvider.class,
           BasicCredentialsProvider.class,
           OAuthM2MServicePrincipalCredentialsProvider.class,
+          GithubOidcCredentialsProvider.class,
           AzureGithubOidcCredentialsProvider.class,
           AzureServicePrincipalCredentialsProvider.class,
           AzureCliCredentialsProvider.class,

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, St
         .withClientId(config.getAzureClientId())
         .withClientSecret(config.getAzureClientSecret())
         .withTokenUrl(tokenUrl)
-        .withEndpointParameters(endpointParams)
+        .withEndpointParametersSupplier(() -> endpointParams)
         .withAuthParameterPosition(AuthParameterPosition.BODY)
         .build();
   }

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java

Lines changed: 17 additions & 12 deletions
@@ -3,6 +3,7 @@
 import com.databricks.sdk.core.commons.CommonsHttpClient;
 import com.databricks.sdk.core.http.HttpClient;
 import java.util.*;
+import java.util.function.Supplier;

 /**
  * An implementation of RefreshableTokenSource implementing the client_credentials OAuth grant type.
@@ -18,7 +19,11 @@ public static class Builder {
     private String clientSecret;
     private String tokenUrl;
     private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
-    private Map<String, String> endpointParams = Collections.emptyMap();
+
+    // Endpoint parameters can include tokens with expiration which
+    // may need to be refreshed. This supplier will be called each time
+    // the credentials are refreshed.
+    private Supplier<Map<String, String>> endpointParamsSupplier = null;
     private List<String> scopes = Collections.emptyList();
     private AuthParameterPosition position = AuthParameterPosition.BODY;

@@ -32,13 +37,14 @@ public Builder withClientSecret(String clientSecret) {
       return this;
     }

-    public Builder withTokenUrl(String tokenUrl) {
-      this.tokenUrl = tokenUrl;
+    public Builder withEndpointParametersSupplier(
+        Supplier<Map<String, String>> endpointParamsSupplier) {
+      this.endpointParamsSupplier = endpointParamsSupplier;
       return this;
     }

-    public Builder withEndpointParameters(Map<String, String> params) {
-      this.endpointParams = params;
+    public Builder withTokenUrl(String tokenUrl) {
+      this.tokenUrl = tokenUrl;
       return this;
     }

@@ -59,34 +65,33 @@ public Builder withHttpClient(HttpClient hc) {

     public ClientCredentials build() {
       Objects.requireNonNull(this.clientId, "clientId must be specified");
-      Objects.requireNonNull(this.clientSecret, "clientSecret must be specified");
       Objects.requireNonNull(this.tokenUrl, "tokenUrl must be specified");
       return new ClientCredentials(
-          hc, clientId, clientSecret, tokenUrl, endpointParams, scopes, position);
+          hc, clientId, clientSecret, tokenUrl, endpointParamsSupplier, scopes, position);
     }
   }

   private HttpClient hc;
   private String clientId;
   private String clientSecret;
   private String tokenUrl;
-  private Map<String, String> endpointParams;
   private List<String> scopes;
   private AuthParameterPosition position;
+  private Supplier<Map<String, String>> endpointParamsSupplier;

   private ClientCredentials(
       HttpClient hc,
       String clientId,
       String clientSecret,
       String tokenUrl,
-      Map<String, String> endpointParams,
+      Supplier<Map<String, String>> endpointParamsSupplier,
       List<String> scopes,
       AuthParameterPosition position) {
     this.hc = hc;
     this.clientId = clientId;
     this.clientSecret = clientSecret;
     this.tokenUrl = tokenUrl;
-    this.endpointParams = endpointParams;
+    this.endpointParamsSupplier = endpointParamsSupplier;
     this.scopes = scopes;
     this.position = position;
   }
@@ -98,8 +103,8 @@ protected Token refresh() {
     if (scopes != null) {
       params.put("scope", String.join(" ", scopes));
     }
-    if (endpointParams != null) {
-      params.putAll(endpointParams);
+    if (endpointParamsSupplier != null) {
+      params.putAll(endpointParamsSupplier.get());
     }
     return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position);
   }
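
The change from a fixed map to a `Supplier` matters because the supplier is evaluated on every refresh, so a short-lived value such as an OIDC token can be fetched anew each time. The sketch below shows only that idea. `RefreshingSource` and `buildRefreshParams` are hypothetical names, and the default `grant_type` entry is an assumption about the surrounding code rather than something shown in the hunk.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Illustrative only: why endpoint parameters are supplied lazily instead of captured once. */
final class EndpointParamsSupplierSketch {

  /** Hypothetical stand-in for a token source that rebuilds its parameters per refresh. */
  static final class RefreshingSource {
    private final Supplier<Map<String, String>> endpointParamsSupplier;

    RefreshingSource(Supplier<Map<String, String>> endpointParamsSupplier) {
      this.endpointParamsSupplier = endpointParamsSupplier;
    }

    /** Builds the request parameters for one refresh; the supplier runs on every call. */
    Map<String, String> buildRefreshParams() {
      Map<String, String> params = new HashMap<>();
      params.put("grant_type", "client_credentials"); // assumed default, overridable below
      if (endpointParamsSupplier != null) {
        params.putAll(endpointParamsSupplier.get());
      }
      return params;
    }
  }

  public static void main(String[] args) {
    AtomicInteger issued = new AtomicInteger();
    // Each refresh obtains a new (fake) subject token instead of reusing a stale one.
    RefreshingSource source =
        new RefreshingSource(
            () -> Collections.singletonMap("subject_token", "token-" + issued.incrementAndGet()));
    System.out.println(source.buildRefreshParams().get("subject_token"));
    System.out.println(source.buildRefreshParams().get("subject_token"));
  }
}
```

Callers that have only a fixed map can wrap it as `() -> fixedMap`, which is what the Azure service principal provider does in this commit.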
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+package com.databricks.sdk.core.oauth;
+
+import com.databricks.sdk.core.DatabricksException;
+import com.databricks.sdk.core.http.HttpClient;
+import com.databricks.sdk.core.http.Request;
+import com.databricks.sdk.core.http.Response;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import java.io.IOException;
+
+public class GitHubOidcTokenSupplier {
+
+  private final ObjectMapper mapper = new ObjectMapper();
+  private final HttpClient httpClient;
+  private final String idTokenRequestUrl;
+  private final String idTokenRequestToken;
+  private final String tokenAudience;
+
+  public GitHubOidcTokenSupplier(
+      HttpClient httpClient,
+      String idTokenRequestUrl,
+      String idTokenRequestToken,
+      String tokenAudience) {
+    this.httpClient = httpClient;
+    this.idTokenRequestUrl = idTokenRequestUrl;
+    this.idTokenRequestToken = idTokenRequestToken;
+    this.tokenAudience = tokenAudience;
+  }
+
+  /** Checks if the required parameters are present to request a GitHub's OIDC token. */
+  public Boolean enabled() {
+    return idTokenRequestUrl != null && idTokenRequestToken != null;
+  }
+
+  /**
+   * Requests a GitHub's OIDC token.
+   *
+   * @return A GitHub OIDC token.
+   */
+  public String getOidcToken() {
+    if (!enabled()) {
+      throw new DatabricksException("Failed to request ID token: missing required parameters");
+    }
+
+    String requestUrl = idTokenRequestUrl;
+    if (tokenAudience != null) {
+      requestUrl += "&audience=" + tokenAudience;
+    }
+
+    Request req =
+        new Request("GET", requestUrl).withHeader("Authorization", "Bearer " + idTokenRequestToken);
+
+    Response resp;
+    try {
+      resp = httpClient.execute(req);
+    } catch (IOException e) {
+      throw new DatabricksException(
+          "Failed to request ID token from " + requestUrl + ":" + e.getMessage(), e);
+    }
+
+    if (resp.getStatusCode() != 200) {
+      throw new DatabricksException(
+          "Failed to request ID token: status code "
+              + resp.getStatusCode()
+              + ", response body: "
+              + resp.getBody().toString());
+    }
+
+    ObjectNode jsonResp;
+    try {
+      jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
+    } catch (IOException e) {
+      throw new DatabricksException(
+          "Failed to request ID token: corrupted token: " + e.getMessage());
+    }
+
+    return jsonResp.get("value").textValue();
+  }
+}
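
The supplier above appends the audience to the request URL as a raw string. As a possible variation, and not what this commit does, the value could be URL-encoded first so that characters such as spaces or `&` cannot alter the query string. The sketch assumes, as the `&audience=` prefix in the diff implies, that the request URL already contains a query string. The class name, method name and URLs are hypothetical.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;

/** Illustrative only: appending an audience query parameter with URL encoding. */
final class OidcRequestUrlSketch {

  /**
   * Returns the request URL with an encoded audience parameter appended, or the URL unchanged
   * when no audience is configured. Assumes the URL already has a query string.
   */
  static String withAudience(String idTokenRequestUrl, String audience) {
    if (audience == null) {
      return idTokenRequestUrl;
    }
    try {
      return idTokenRequestUrl + "&audience=" + URLEncoder.encode(audience, "UTF-8");
    } catch (UnsupportedEncodingException e) {
      // UTF-8 is a charset every Java runtime must provide, so this is not expected.
      throw new IllegalStateException("UTF-8 is not available", e);
    }
  }

  public static void main(String[] args) {
    System.out.println(withAudience("https://example.invalid/token?api-version=2.0", "a b&c"));
  }
}
```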
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+package com.databricks.sdk.core.oauth;
+
+import com.databricks.sdk.core.CredentialsProvider;
+import com.databricks.sdk.core.DatabricksConfig;
+import com.databricks.sdk.core.DatabricksException;
+import com.databricks.sdk.core.HeaderFactory;
+import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * GithubOidcCredentialsProvider uses a Token Supplier to get a GitHub OIDC JWT Token and exchanges
+ * it for a Databricks Token.
+ */
+public class GithubOidcCredentialsProvider implements CredentialsProvider {
+
+  @Override
+  public String authType() {
+    return "github-oidc";
+  }
+
+  @Override
+  public HeaderFactory configure(DatabricksConfig config) throws DatabricksException {
+    GitHubOidcTokenSupplier idTokenProvider =
+        new GitHubOidcTokenSupplier(
+            config.getHttpClient(),
+            config.getActionsIdTokenRequestUrl(),
+            config.getActionsIdTokenRequestToken(),
+            config.getTokenAudience());
+
+    if (!idTokenProvider.enabled() || config.getHost() == null || config.getClientId() == null) {
+      return null;
+    }
+
+    String endpointUrl;
+
+    try {
+      endpointUrl = config.getOidcEndpoints().getTokenEndpoint();
+    } catch (IOException e) {
+      throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e);
+    }
+
+    ClientCredentials clientCredentials =
+        new ClientCredentials.Builder()
+            .withHttpClient(config.getHttpClient())
+            .withClientId(config.getClientId())
+            .withTokenUrl(endpointUrl)
+            .withScopes(Collections.singletonList("all-apis"))
+            .withAuthParameterPosition(AuthParameterPosition.HEADER)
+            .withEndpointParametersSupplier(
+                () ->
+                    new ImmutableMap.Builder<String, String>()
+                        .put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
+                        .put("subject_token", idTokenProvider.getOidcToken())
+                        .put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
+                        .build())
+            .build();
+
+    return () -> {
+      Map<String, String> headers = new HashMap<>();
+      headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken());
+      return headers;
+    };
+  }
+}
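
To make the exchange step more concrete, the sketch below assembles the token-exchange parameters used by the provider above into a form-encoded request body. It covers only the parameter set taken from the diff (`grant_type`, `subject_token_type`, `subject_token` and the `all-apis` scope). How the SDK actually serializes and sends the request, including where the client ID goes, is not shown here, and the class and method names are hypothetical.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.StringJoiner;

/** Illustrative only: the token-exchange parameters as a form-encoded body. */
final class TokenExchangeFormSketch {

  /** Collects the parameters named in the diff for exchanging a GitHub OIDC JWT. */
  static Map<String, String> tokenExchangeParams(String oidcJwt) {
    Map<String, String> params = new LinkedHashMap<>();
    params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange");
    params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt");
    params.put("subject_token", oidcJwt);
    params.put("scope", "all-apis");
    return params;
  }

  /** Encodes the parameters as application/x-www-form-urlencoded, in insertion order. */
  static String toFormBody(Map<String, String> params) {
    StringJoiner body = new StringJoiner("&");
    for (Map.Entry<String, String> entry : params.entrySet()) {
      body.add(encode(entry.getKey()) + "=" + encode(entry.getValue()));
    }
    return body.toString();
  }

  private static String encode(String value) {
    try {
      return URLEncoder.encode(value, "UTF-8");
    } catch (UnsupportedEncodingException e) {
      // UTF-8 is a charset every Java runtime must provide, so this is not expected.
      throw new IllegalStateException("UTF-8 is not available", e);
    }
  }

  public static void main(String[] args) {
    // "header.payload.signature" is a placeholder, not a real JWT.
    System.out.println(toFormBody(tokenExchangeParams("header.payload.signature")));
  }
}
```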
