Skip to content

Commit ed4992d

Browse files
authored
[Fix] Retry on too many auth requests (#355)
## Changes <!-- Summary of your changes that are easy to understand --> This PR addresses an issue encountered during authentication, where the server returns a 429 status code when fetching OIDC endpoints from the well-known location, particularly when authentication requests are frequent. The update introduces a retry mechanism within the auth flow to mitigate the impact of 429 errors. `ApiClient` construction was modified to allow unauthenticated calls: the authentication functions are now injected, and within the OIDC workflow they act as no-ops. In these cases `DatabricksConfig` now instantiates an `ApiClient` instead of calling a standard `CommonsHttpClient` directly. ## Tests <!-- How is this tested? --> * Added unit tests to verify the retry functionality upon encountering a 429 response. * Included a manual test to simulate real-world scenarios, ensuring the retry mechanism effectively handles 429 errors during a complete authentication workflow, especially under high concurrent request conditions. --------- Signed-off-by: Omer Lachish <[email protected]> Co-authored-by: Omer Lachish <[email protected]>
1 parent 1c44950 commit ed4992d

File tree

12 files changed

+247
-59
lines changed

12 files changed

+247
-59
lines changed

databricks-sdk-java/src/main/java/com/databricks/sdk/core/ApiClient.java

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.time.ZonedDateTime;
2222
import java.time.format.DateTimeFormatter;
2323
import java.util.*;
24+
import java.util.function.Function;
2425
import org.slf4j.Logger;
2526
import org.slf4j.LoggerFactory;
2627

@@ -29,55 +30,125 @@
2930
* guessing
3031
*/
3132
public class ApiClient {
33+
public static class Builder {
34+
private Timer timer;
35+
private Function<Void, Map<String, String>> authenticateFunc;
36+
private Function<Void, String> getHostFunc;
37+
private Function<Void, String> getAuthTypeFunc;
38+
private Integer debugTruncateBytes;
39+
private HttpClient httpClient;
40+
private String accountId;
41+
private RetryStrategyPicker retryStrategyPicker;
42+
private boolean isDebugHeaders;
43+
44+
public Builder withDatabricksConfig(DatabricksConfig config) {
45+
this.authenticateFunc = v -> config.authenticate();
46+
this.getHostFunc = v -> config.getHost();
47+
this.getAuthTypeFunc = v -> config.getAuthType();
48+
this.httpClient = config.getHttpClient();
49+
this.debugTruncateBytes = config.getDebugTruncateBytes();
50+
this.accountId = config.getAccountId();
51+
this.retryStrategyPicker = new RequestBasedRetryStrategyPicker(config.getHost());
52+
this.isDebugHeaders = config.isDebugHeaders();
53+
54+
return this;
55+
}
56+
57+
public Builder withTimer(Timer timer) {
58+
this.timer = timer;
59+
return this;
60+
}
61+
62+
public Builder withAuthenticateFunc(Function<Void, Map<String, String>> authenticateFunc) {
63+
this.authenticateFunc = authenticateFunc;
64+
return this;
65+
}
66+
67+
public Builder withGetHostFunc(Function<Void, String> getHostFunc) {
68+
this.getHostFunc = getHostFunc;
69+
return this;
70+
}
71+
72+
public Builder withGetAuthTypeFunc(Function<Void, String> getAuthTypeFunc) {
73+
this.getAuthTypeFunc = getAuthTypeFunc;
74+
return this;
75+
}
76+
77+
public Builder withHttpClient(HttpClient httpClient) {
78+
this.httpClient = httpClient;
79+
return this;
80+
}
81+
82+
public Builder withRetryStrategyPicker(RetryStrategyPicker retryStrategyPicker) {
83+
this.retryStrategyPicker = retryStrategyPicker;
84+
return this;
85+
}
86+
87+
public ApiClient build() {
88+
return new ApiClient(this);
89+
}
90+
}
91+
3292
private static final Logger LOG = LoggerFactory.getLogger(ApiClient.class);
3393

3494
private final int maxAttempts;
3595

3696
private final ObjectMapper mapper;
3797

38-
private final DatabricksConfig config;
39-
4098
private final Random random;
4199

42100
private final HttpClient httpClient;
43101
private final BodyLogger bodyLogger;
44102
private final RetryStrategyPicker retryStrategyPicker;
45103
private final Timer timer;
104+
private final Function<Void, Map<String, String>> authenticateFunc;
105+
private final Function<Void, String> getHostFunc;
106+
private final Function<Void, String> getAuthTypeFunc;
107+
private final String accountId;
108+
private final boolean isDebugHeaders;
46109
private static final String RETRY_AFTER_HEADER = "retry-after";
47110

48111
public ApiClient() {
49112
this(ConfigLoader.getDefault());
50113
}
51114

52115
public String configuredAccountID() {
53-
return config.getAccountId();
116+
return accountId;
54117
}
55118

56119
public ApiClient(DatabricksConfig config) {
57120
this(config, new SystemTimer());
58121
}
59122

60123
public ApiClient(DatabricksConfig config, Timer timer) {
61-
this.config = config;
62-
config.resolve();
63-
64-
Integer rateLimit = config.getRateLimit();
65-
if (rateLimit == null) {
66-
rateLimit = 15;
67-
}
68-
69-
Integer debugTruncateBytes = config.getDebugTruncateBytes();
124+
this(new Builder().withDatabricksConfig(config.resolve()).withTimer(timer));
125+
}
126+
127+
private ApiClient(Builder builder) {
128+
this.timer = builder.timer != null ? builder.timer : new SystemTimer();
129+
this.authenticateFunc =
130+
builder.authenticateFunc != null
131+
? builder.authenticateFunc
132+
: v -> new HashMap<String, String>();
133+
this.getHostFunc = builder.getHostFunc != null ? builder.getHostFunc : v -> "";
134+
this.getAuthTypeFunc = builder.getAuthTypeFunc != null ? builder.getAuthTypeFunc : v -> "";
135+
this.httpClient = builder.httpClient;
136+
this.accountId = builder.accountId;
137+
this.retryStrategyPicker =
138+
builder.retryStrategyPicker != null
139+
? builder.retryStrategyPicker
140+
: new RequestBasedRetryStrategyPicker(this.getHostFunc.apply(null));
141+
this.isDebugHeaders = builder.isDebugHeaders;
142+
143+
Integer debugTruncateBytes = builder.debugTruncateBytes;
70144
if (debugTruncateBytes == null) {
71145
debugTruncateBytes = 96;
72146
}
73147

74148
maxAttempts = 4;
75149
mapper = SerDeUtils.createMapper();
76150
random = new Random();
77-
httpClient = config.getHttpClient();
78151
bodyLogger = new BodyLogger(mapper, 1024, debugTruncateBytes);
79-
retryStrategyPicker = new RequestBasedRetryStrategyPicker(this.config);
80-
this.timer = timer;
81152
}
82153

83154
private static <I> void setQuery(Request in, I entity) {
@@ -203,7 +274,7 @@ private <I> Request prepareBaseRequest(String method, String path, I in)
203274
InputStream body = (InputStream) in;
204275
return new Request(method, path, body);
205276
} else {
206-
String body = serialize(in);
277+
String body = (in instanceof String) ? (String) in : serialize(in);
207278
return new Request(method, path, body);
208279
}
209280
}
@@ -245,15 +316,19 @@ private Response executeInner(Request in, String path) {
245316
Response out = null;
246317

247318
// Authenticate the request. Failures should not be retried.
248-
in.withHeaders(config.authenticate());
319+
in.withHeaders(authenticateFunc.apply(null));
249320

250321
// Prepend host to URL only after config.authenticate().
251322
// This call may configure the host (e.g. in case of notebook native auth).
252-
in.withUrl(config.getHost() + path);
323+
in.withUrl(getHostFunc.apply(null) + path);
253324

254325
// Set User-Agent with auth type info, which is available only
255326
// after the first invocation to config.authenticate()
256-
String userAgent = String.format("%s auth/%s", UserAgent.asString(), config.getAuthType());
327+
String userAgent = UserAgent.asString();
String authType = getAuthTypeFunc.apply(null);
// Compare by value, not by reference: `!=` on Strings only checks object identity, so an
// empty auth type that is not the interned "" literal would still append a dangling
// " auth/" token. A null auth type is skipped too, instead of being rendered as "auth/null".
if (authType != null && !authType.isEmpty()) {
  userAgent += String.format(" auth/%s", authType);
}
257332
in.withHeader("User-Agent", userAgent);
258333

259334
// Make the request, catching any exceptions, as we may want to retry.
@@ -347,9 +422,9 @@ private String makeLogRecord(Request in, Response out) {
347422
StringBuilder sb = new StringBuilder();
348423
sb.append("> ");
349424
sb.append(in.getRequestLine());
350-
if (config.isDebugHeaders()) {
425+
if (this.isDebugHeaders) {
351426
sb.append("\n * Host: ");
352-
sb.append(config.getHost());
427+
sb.append(this.getHostFunc.apply(null));
353428
in.getHeaders()
354429
.forEach((header, value) -> sb.append(String.format("\n * %s: %s", header, value)));
355430
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -581,12 +581,15 @@ private OpenIDConnectEndpoints fetchDefaultOidcEndpoints() throws IOException {
581581
return new OpenIDConnectEndpoints(prefix + "/v1/token", prefix + "/v1/authorize");
582582
}
583583

584-
String oidcEndpoint = getHost() + "/oidc/.well-known/oauth-authorization-server";
585-
Response resp = getHttpClient().execute(new Request("GET", oidcEndpoint));
586-
if (resp.getStatusCode() != 200) {
587-
return null;
588-
}
589-
return new ObjectMapper().readValue(resp.getBody(), OpenIDConnectEndpoints.class);
584+
ApiClient apiClient =
585+
new ApiClient.Builder()
586+
.withHttpClient(getHttpClient())
587+
.withGetHostFunc(v -> getHost())
588+
.build();
589+
return apiClient.GET(
590+
"/oidc/.well-known/oauth-authorization-server",
591+
OpenIDConnectEndpoints.class,
592+
new HashMap<>());
590593
}
591594

592595
@Override

databricks-sdk-java/src/main/java/com/databricks/sdk/core/http/FormRequest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ public FormRequest(String url, Map<String, String> form) {
88
}
99

1010
public FormRequest(String method, String url, Map<String, String> form) {
11-
super(method, url, mapToQuery(wrapValuesInList(form)));
11+
super(method, url, wrapValuesInList(form));
1212
withHeader("Content-Type", "application/x-www-form-urlencoded");
1313
}
1414

15-
static Map<String, List<String>> wrapValuesInList(Map<String, String> map) {
15+
public static String wrapValuesInList(Map<String, String> map) {
1616
Map<String, List<String>> m = new LinkedHashMap<>();
1717
for (Map.Entry<String, String> entry : map.entrySet()) {
1818
m.put(entry.getKey(), Collections.singletonList(entry.getValue()));
1919
}
20-
return m;
20+
return mapToQuery(m);
2121
}
2222
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OpenIDConnectEndpoints.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
*/
1212
@JsonIgnoreProperties(ignoreUnknown = true)
1313
public class OpenIDConnectEndpoints {
14+
@JsonProperty("token_endpoint")
1415
private String tokenEndpoint;
1516

17+
@JsonProperty("authorization_endpoint")
1618
private String authorizationEndpoint;
1719

20+
public OpenIDConnectEndpoints() {}
21+
1822
public OpenIDConnectEndpoints(
1923
@JsonProperty("token_endpoint") String tokenEndpoint,
2024
@JsonProperty("authorization_endpoint") String authorizationEndpoint)

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package com.databricks.sdk.core.oauth;
22

3+
import com.databricks.sdk.core.ApiClient;
34
import com.databricks.sdk.core.DatabricksException;
45
import com.databricks.sdk.core.http.FormRequest;
56
import com.databricks.sdk.core.http.HttpClient;
6-
import com.databricks.sdk.core.http.Response;
7-
import com.fasterxml.jackson.databind.ObjectMapper;
8-
import java.io.IOException;
97
import java.time.LocalDateTime;
108
import java.time.temporal.ChronoUnit;
119
import java.util.Base64;
@@ -62,17 +60,19 @@ protected static Token retrieveToken(
6260
headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue);
6361
break;
6462
}
65-
FormRequest req = new FormRequest(tokenUrl, params);
66-
req.withHeaders(headers);
63+
headers.put("Content-Type", "application/x-www-form-urlencoded");
6764
try {
68-
Response rawResp = hc.execute(req);
69-
OAuthResponse resp = new ObjectMapper().readValue(rawResp.getBody(), OAuthResponse.class);
65+
ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build();
66+
67+
OAuthResponse resp =
68+
apiClient.POST(
69+
tokenUrl, FormRequest.wrapValuesInList(params), OAuthResponse.class, headers);
7070
if (resp.getErrorCode() != null) {
7171
throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary());
7272
}
7373
LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS);
7474
return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry);
75-
} catch (IOException e) {
75+
} catch (Exception e) {
7676
throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e);
7777
}
7878
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/retry/RequestBasedRetryStrategyPicker.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.databricks.sdk.core.retry;
22

3-
import com.databricks.sdk.core.DatabricksConfig;
43
import com.databricks.sdk.core.http.Request;
54
import java.util.AbstractMap;
65
import java.util.Arrays;
@@ -37,15 +36,14 @@ public class RequestBasedRetryStrategyPicker implements RetryStrategyPicker {
3736
private static final IdempotentRequestRetryStrategy IDEMPOTENT_RETRY_STRATEGY =
3837
new IdempotentRequestRetryStrategy();
3938

40-
public RequestBasedRetryStrategyPicker(DatabricksConfig config) {
39+
public RequestBasedRetryStrategyPicker(String host) {
4140
this.idempotentRequestsPattern =
4241
IDEMPOTENT_REQUESTS.stream()
4342
.map(
4443
request ->
4544
new AbstractMap.SimpleEntry<>(
4645
request.getMethod(),
47-
Pattern.compile(
48-
config.getHost() + request.getUrl(), Pattern.CASE_INSENSITIVE)))
46+
Pattern.compile(host + request.getUrl(), Pattern.CASE_INSENSITIVE)))
4947
.collect(Collectors.toList());
5048
}
5149

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package com.databricks.sdk.benchmark;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import java.util.concurrent.*;
6+
7+
/*
8+
This test executes the authentication workflow 200 times concurrently and verifies that all 200 runs complete successfully.
9+
It is designed to address a previously observed issue where multiple SDK operations needed to authenticate, causing the OIDC endpoints to rate-limit the requests.
10+
Now that these endpoints are configured to retry upon receiving a 429 error, the test runs successfully.
11+
However, since this test generates a large number of requests, it should be run manually rather than being included in CI processes.
12+
*/
13+
/*
14+
public class DatabricksAuthLoadTest implements GitHubUtils, ConfigResolving {
15+
16+
@Test
17+
@Disabled
18+
public void testConcurrentConfigBasicAuthAttrs() throws Exception {
19+
int numThreads = 200;
20+
ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
21+
List<Future<Boolean>> futures = new ArrayList<>();
22+
int successCount = 0;
23+
int failureCount = 0;
24+
25+
Callable<Boolean> task =
26+
() -> {
27+
try {
28+
DatabricksConfig config =
29+
new DatabricksConfig()
30+
.setHost("https://dbc-bb03964f-3f59.cloud.databricks.com")
31+
.setClientId("<<REDACTED>>")
32+
.setClientSecret("<<REDACTED>>");
33+
34+
config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
35+
config.authenticate();
36+
37+
assertEquals("oauth-m2m", config.getAuthType());
38+
39+
return true;
40+
} catch (Exception e) {
41+
System.err.println(
42+
"DatabricksException occurred in thread " + Thread.currentThread().getName());
43+
e.printStackTrace();
44+
return false;
45+
}
46+
};
47+
48+
for (int i = 0; i < numThreads; i++) {
49+
futures.add(executorService.submit(task));
50+
}
51+
52+
executorService.shutdown();
53+
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
54+
executorService.shutdownNow();
55+
}
56+
57+
for (Future<Boolean> future : futures) {
58+
if (future.get()) {
59+
successCount++;
60+
} else {
61+
failureCount++;
62+
}
63+
}
64+
65+
// Log the results
66+
System.out.println("Number of successful threads: " + successCount);
67+
System.out.println("Number of failed threads: " + failureCount);
68+
69+
// Optionally, you can assert that there were no failures
70+
assertEquals(0, failureCount);
71+
}
72+
}
73+
*/

0 commit comments

Comments
 (0)