Skip to content

Commit a13f967

Browse files
stainless-app[bot]stainless-bot
authored andcommitted
fix(client): don't leak responses when retrying (#185)
1 parent e453789 commit a13f967

File tree

2 files changed

+71
-17
lines changed

2 files changed

+71
-17
lines changed

openai-java-core/src/main/kotlin/com/openai/core/http/RetryingHttpClient.kt

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,17 @@ private constructor(
5757
}
5858

5959
response
60-
} catch (t: Throwable) {
61-
if (++retries > maxRetries || !shouldRetry(t)) {
62-
throw t
60+
} catch (throwable: Throwable) {
61+
if (++retries > maxRetries || !shouldRetry(throwable)) {
62+
throw throwable
6363
}
6464

6565
null
6666
}
6767

6868
val backoffMillis = getRetryBackoffMillis(retries, response)
69+
// All responses must be closed, so close the failed one before retrying.
70+
response?.close()
6971
Thread.sleep(backoffMillis.toMillis())
7072
}
7173
}
@@ -113,6 +115,8 @@ private constructor(
113115
}
114116

115117
val backoffMillis = getRetryBackoffMillis(retries, response)
118+
// All responses must be closed, so close the failed one before retrying.
119+
response?.close()
116120
return sleepAsync(backoffMillis.toMillis()).thenCompose {
117121
executeWithRetries(requestWithRetryCount, requestOptions)
118122
}
@@ -223,23 +227,23 @@ private constructor(
223227
return Duration.ofNanos((TimeUnit.SECONDS.toNanos(1) * backoffSeconds * jitter).toLong())
224228
}
225229

226-
private fun sleepAsync(millis: Long): CompletableFuture<Void> {
227-
val future = CompletableFuture<Void>()
228-
TIMER.schedule(
229-
object : TimerTask() {
230-
override fun run() {
231-
future.complete(null)
232-
}
233-
},
234-
millis
235-
)
236-
return future
237-
}
238-
239230
companion object {
240231

241232
private val TIMER = Timer("RetryingHttpClient", true)
242233

234+
private fun sleepAsync(millis: Long): CompletableFuture<Void> {
235+
val future = CompletableFuture<Void>()
236+
TIMER.schedule(
237+
object : TimerTask() {
238+
override fun run() {
239+
future.complete(null)
240+
}
241+
},
242+
millis
243+
)
244+
return future
245+
}
246+
243247
@JvmStatic fun builder() = Builder()
244248
}
245249

openai-java-core/src/test/kotlin/com/openai/core/http/RetryingHttpClientTest.kt

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo
55
import com.github.tomakehurst.wiremock.junit5.WireMockTest
66
import com.github.tomakehurst.wiremock.stubbing.Scenario
77
import com.openai.client.okhttp.OkHttpClient
8+
import com.openai.core.RequestOptions
9+
import java.io.InputStream
10+
import java.util.concurrent.CompletableFuture
811
import org.assertj.core.api.Assertions.assertThat
912
import org.junit.jupiter.api.BeforeEach
1013
import org.junit.jupiter.params.ParameterizedTest
@@ -13,11 +16,49 @@ import org.junit.jupiter.params.provider.ValueSource
1316
@WireMockTest
1417
internal class RetryingHttpClientTest {
1518

19+
private var openResponseCount = 0
1620
private lateinit var httpClient: HttpClient
1721

1822
@BeforeEach
1923
fun beforeEach(wmRuntimeInfo: WireMockRuntimeInfo) {
20-
httpClient = OkHttpClient.builder().baseUrl(wmRuntimeInfo.httpBaseUrl).build()
24+
val okHttpClient = OkHttpClient.builder().baseUrl(wmRuntimeInfo.httpBaseUrl).build()
25+
httpClient =
26+
object : HttpClient {
27+
override fun execute(
28+
request: HttpRequest,
29+
requestOptions: RequestOptions
30+
): HttpResponse = trackClose(okHttpClient.execute(request, requestOptions))
31+
32+
override fun executeAsync(
33+
request: HttpRequest,
34+
requestOptions: RequestOptions
35+
): CompletableFuture<HttpResponse> =
36+
okHttpClient.executeAsync(request, requestOptions).thenApply { trackClose(it) }
37+
38+
override fun close() = okHttpClient.close()
39+
40+
private fun trackClose(response: HttpResponse): HttpResponse {
41+
openResponseCount++
42+
return object : HttpResponse {
43+
private var isClosed = false
44+
45+
override fun statusCode(): Int = response.statusCode()
46+
47+
override fun headers(): Headers = response.headers()
48+
49+
override fun body(): InputStream = response.body()
50+
51+
override fun close() {
52+
response.close()
53+
if (isClosed) {
54+
return
55+
}
56+
openResponseCount--
57+
isClosed = true
58+
}
59+
}
60+
}
61+
}
2162
resetAllScenarios()
2263
}
2364

@@ -35,6 +76,7 @@ internal class RetryingHttpClientTest {
3576

3677
assertThat(response.statusCode()).isEqualTo(200)
3778
verify(1, postRequestedFor(urlPathEqualTo("/something")))
79+
assertNoResponseLeaks()
3880
}
3981

4082
@ParameterizedTest
@@ -60,6 +102,7 @@ internal class RetryingHttpClientTest {
60102

61103
assertThat(response.statusCode()).isEqualTo(200)
62104
verify(1, postRequestedFor(urlPathEqualTo("/something")))
105+
assertNoResponseLeaks()
63106
}
64107

65108
@ParameterizedTest
@@ -116,6 +159,7 @@ internal class RetryingHttpClientTest {
116159
postRequestedFor(urlPathEqualTo("/something"))
117160
.withHeader("x-stainless-retry-count", equalTo("2"))
118161
)
162+
assertNoResponseLeaks()
119163
}
120164

121165
@ParameterizedTest
@@ -156,6 +200,7 @@ internal class RetryingHttpClientTest {
156200
postRequestedFor(urlPathEqualTo("/something"))
157201
.withHeader("x-stainless-retry-count", equalTo("42"))
158202
)
203+
assertNoResponseLeaks()
159204
}
160205

161206
@ParameterizedTest
@@ -186,8 +231,13 @@ internal class RetryingHttpClientTest {
186231

187232
assertThat(response.statusCode()).isEqualTo(200)
188233
verify(2, postRequestedFor(urlPathEqualTo("/something")))
234+
assertNoResponseLeaks()
189235
}
190236

191237
private fun HttpClient.execute(request: HttpRequest, async: Boolean): HttpResponse =
192238
if (async) executeAsync(request).get() else execute(request)
239+
240+
// When retrying, all failed responses should be closed. Only the final returned response should
241+
// be open.
242+
private fun assertNoResponseLeaks() = assertThat(openResponseCount).isEqualTo(1)
193243
}

0 commit comments

Comments
 (0)