Skip to content

Commit 5da8113

Browse files
authored
Fixes #4153: Handling OpenAI 429's gracefully (#4284)
* Fixes #4153: Handling OpenAI 429's gracefully * cleanup * fix tests
1 parent db54aa4 commit 5da8113

File tree

4 files changed

+209
-11
lines changed

4 files changed

+209
-11
lines changed

docs/asciidoc/modules/ROOT/pages/ml/openai.adoc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ If present, they take precedence over the analogous APOC configs.
3737
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
3838
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
3939
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
40+
| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false)
41+
| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
42+
| exponentialBackoff | If set to true, the wait time between retries doubles after each failed attempt (2s, 4s, 8s, ...). If set to false, the wait time increases linearly (1s, 2s, 3s, ...). In both cases a single wait never exceeds 30 seconds. (default: false)
4043
|===
4144

4245

extended/src/main/java/apoc/ml/OpenAI.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import apoc.ApocConfig;
44
import apoc.Extended;
55
import apoc.result.MapResult;
6+
import apoc.util.ExtendedUtil;
67
import apoc.util.JsonUtil;
78
import apoc.util.Util;
89
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -36,6 +37,9 @@ public class OpenAI {
3637
public static final String PATH_CONF_KEY = "path";
3738
public static final String GPT_4O_MODEL = "gpt-4o";
3839
public static final String FAIL_ON_ERROR_CONF = "failOnError";
40+
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
41+
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
42+
public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";
3943

4044
@Context
4145
public ApocConfig apocConfig;
@@ -59,6 +63,9 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
5963

6064
static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
6165
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
66+
boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) );
67+
Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
68+
boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) );
6269
if (apiKey == null || apiKey.isBlank())
6370
throw new IllegalArgumentException("API Key must not be empty");
6471

@@ -78,7 +85,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
7885
path = (String) configuration.getOrDefault(PATH_CONF_KEY, path);
7986
OpenAIRequestHandler apiType = type.get();
8087

81-
jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
88+
String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
8289
headers.put("Content-Type", "application/json");
8390
apiType.addApiKey(headers, apiKey);
8491

@@ -88,7 +95,14 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
8895
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
8996
// therefore is better to join the not-empty path pieces
9097
var url = apiType.getFullUrl(path, configuration, apocConfig);
91-
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
98+
return ExtendedUtil.withBackOffRetries(
99+
() -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker),
100+
enableBackOffRetries, backOffRetries, exponentialBackoff,
101+
exception -> {
102+
if(!exception.getMessage().contains("429"))
103+
throw new RuntimeException(exception);
104+
}
105+
);
92106
}
93107

94108
private static void handleAPIProvider(OpenAIRequestHandler.Type type,

extended/src/main/java/apoc/util/ExtendedUtil.java

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,9 @@
3030
import java.time.ZoneId;
3131
import java.time.ZonedDateTime;
3232
import java.time.temporal.TemporalAccessor;
33-
import java.util.ArrayList;
34-
import java.util.Arrays;
35-
import java.util.Collection;
36-
import java.util.HashMap;
37-
import java.util.HashSet;
38-
import java.util.List;
39-
import java.util.Map;
40-
import java.util.Set;
33+
import java.util.*;
34+
import java.util.function.Consumer;
35+
import java.util.function.Supplier;
4136
import java.util.stream.Collectors;
4237
import java.util.stream.LongStream;
4338
import java.util.stream.Stream;
@@ -353,5 +348,59 @@ public static float[] listOfNumbersToFloatArray(List<? extends Number> embedding
353348
}
354349
return floats;
355350
}
356-
351+
352+
public static <T> T withBackOffRetries(
353+
Supplier<T> func,
354+
boolean retry,
355+
int backoffRetry,
356+
boolean exponential,
357+
Consumer<Exception> exceptionHandler
358+
) {
359+
T result;
360+
backoffRetry = backoffRetry < 1
361+
? 5
362+
: backoffRetry;
363+
int countDown = backoffRetry;
364+
exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {});
365+
while (true) {
366+
try {
367+
result = func.get();
368+
break;
369+
} catch (Exception e) {
370+
if(!retry || countDown < 1) throw e;
371+
exceptionHandler.accept(e);
372+
countDown--;
373+
long delay = getDelay(backoffRetry, countDown, exponential);
374+
backoffSleep(delay);
375+
}
376+
}
377+
return result;
378+
}
379+
380+
private static void backoffSleep(long millis){
381+
sleep(millis, "Operation interrupted during backoff");
382+
}
383+
384+
public static void sleep(long millis, String interruptedMessage) {
385+
try {
386+
Thread.sleep(millis);
387+
} catch (InterruptedException ie) {
388+
Thread.currentThread().interrupt();
389+
throw new RuntimeException(interruptedMessage, ie);
390+
}
391+
}
392+
393+
private static long getDelay(int backoffRetry, int countDown, boolean exponential) {
394+
int backOffTime = backoffRetry - countDown;
395+
long sleepMultiplier = exponential ?
396+
(long) Math.pow(2, backOffTime) : // Exponential retry progression
397+
backOffTime; // Linear retry progression
398+
return Math.min(
399+
Duration.ofSeconds(1)
400+
.multipliedBy(sleepMultiplier)
401+
.toMillis(),
402+
Duration.ofSeconds(30).toMillis() // Max 30s
403+
);
404+
}
405+
357406
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package apoc.util;
2+
3+
import org.junit.Test;
4+
5+
import static org.junit.Assert.*;
6+
import static org.junit.Assert.assertTrue;
7+
8+
public class ExtendedUtilTest {
9+
10+
private static int i = 0;
11+
12+
@Test
13+
public void testWithLinearBackOffRetriesWithSuccess() {
14+
i = 0;
15+
long start = System.currentTimeMillis();
16+
int result = ExtendedUtil.withBackOffRetries(
17+
this::testFunction,
18+
true,
19+
-1, // test backoffRetry default value -> 5
20+
false,
21+
runEx -> {
22+
if(!runEx.getMessage().contains("Expected"))
23+
throw new RuntimeException("Some Bad News...");
24+
}
25+
);
26+
long time = System.currentTimeMillis() - start;
27+
28+
assertEquals(4, result);
29+
30+
// The method will attempt to execute the operation with a linear backoff strategy,
31+
// sleeping for 1 second, 2 seconds, and 3 seconds between retries.
32+
// This results in a total wait time of 6 seconds (1s + 2s + 3s + 4s) if the operation succeeds on the third attempt,
33+
// leading to an approximate execution time of 6 seconds.
34+
assertTrue("Current time is: " + time,
35+
time > 9000 && time < 11000);
36+
}
37+
38+
@Test
39+
public void testWithExponentialBackOffRetriesWithSuccess() {
40+
i = 0;
41+
long start = System.currentTimeMillis();
42+
int result = ExtendedUtil.withBackOffRetries(
43+
this::testFunction,
44+
true,
45+
0, // test backoffRetry default value -> 5
46+
true,
47+
runEx -> {}
48+
);
49+
long time = System.currentTimeMillis() - start;
50+
51+
assertEquals(4, result);
52+
53+
// The method will attempt to execute the operation with an exponential backoff strategy,
54+
// sleeping for 2 second, 4 seconds, and 8 seconds between retries.
55+
// This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) if the operation succeeds on the third attempt,
56+
// leading to an approximate execution time of 14 seconds.
57+
assertTrue("Current time is: " + time,
58+
time > 29000 && time < 31000);
59+
}
60+
61+
@Test
62+
public void testBackOffRetriesWithError() {
63+
i = 0;
64+
long start = System.currentTimeMillis();
65+
assertThrows(
66+
RuntimeException.class,
67+
() -> ExtendedUtil.withBackOffRetries(
68+
this::testFunction,
69+
true,
70+
2,
71+
false,
72+
runEx -> {}
73+
)
74+
);
75+
long time = System.currentTimeMillis() - start;
76+
77+
// The method is configured to retry the operation twice.
78+
// So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception.
79+
// Resulting in an approximate execution time of 3 seconds.
80+
assertTrue("Current time is: " + time,
81+
time > 2000 && time < 4000);
82+
}
83+
84+
@Test
85+
public void testBackOffRetriesWithErrorAndExponential() {
86+
i = 0;
87+
long start = System.currentTimeMillis();
88+
assertThrows(
89+
RuntimeException.class,
90+
() -> ExtendedUtil.withBackOffRetries(
91+
this::testFunction,
92+
true,
93+
2,
94+
true,
95+
runEx -> {}
96+
)
97+
);
98+
long time = System.currentTimeMillis() - start;
99+
100+
// The method is configured to retry the operation twice.
101+
// So, it will make two extra-attempts, waiting for 2 second and 4 seconds before failing and throwing an exception.
102+
// Resulting in an approximate execution time of 6 seconds.
103+
assertTrue("Current time is: " + time,
104+
time > 5000 && time < 7000);
105+
}
106+
107+
@Test
108+
public void testWithoutBackOffRetriesWithError() {
109+
i = 0;
110+
assertThrows(
111+
RuntimeException.class,
112+
() -> ExtendedUtil.withBackOffRetries(
113+
this::testFunction,
114+
false, 30,
115+
false,
116+
runEx -> {}
117+
)
118+
);
119+
120+
// Retry strategy is not active and the testFunction is executed only once by raising an exception.
121+
assertEquals(1, i);
122+
}
123+
124+
private int testFunction() {
125+
if (i == 4) {
126+
return i;
127+
}
128+
i++;
129+
throw new RuntimeException("Expected i not equal to 4");
130+
}
131+
132+
}

0 commit comments

Comments
 (0)