@@ -5,6 +5,9 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo
55import com.github.tomakehurst.wiremock.junit5.WireMockTest
66import com.github.tomakehurst.wiremock.stubbing.Scenario
77import com.openai.client.okhttp.OkHttpClient
8+ import com.openai.core.RequestOptions
9+ import java.io.InputStream
10+ import java.util.concurrent.CompletableFuture
811import org.assertj.core.api.Assertions.assertThat
912import org.junit.jupiter.api.BeforeEach
1013import org.junit.jupiter.params.ParameterizedTest
@@ -13,11 +16,49 @@ import org.junit.jupiter.params.provider.ValueSource
1316@WireMockTest
1417internal class RetryingHttpClientTest {
1518
19+ private var openResponseCount = 0
1620 private lateinit var httpClient: HttpClient
1721
1822 @BeforeEach
1923 fun beforeEach (wmRuntimeInfo : WireMockRuntimeInfo ) {
20- httpClient = OkHttpClient .builder().baseUrl(wmRuntimeInfo.httpBaseUrl).build()
24+ val okHttpClient = OkHttpClient .builder().baseUrl(wmRuntimeInfo.httpBaseUrl).build()
25+ httpClient =
26+ object : HttpClient {
27+ override fun execute (
28+ request : HttpRequest ,
29+ requestOptions : RequestOptions
30+ ): HttpResponse = trackClose(okHttpClient.execute(request, requestOptions))
31+
32+ override fun executeAsync (
33+ request : HttpRequest ,
34+ requestOptions : RequestOptions
35+ ): CompletableFuture <HttpResponse > =
36+ okHttpClient.executeAsync(request, requestOptions).thenApply { trackClose(it) }
37+
38+ override fun close () = okHttpClient.close()
39+
40+ private fun trackClose (response : HttpResponse ): HttpResponse {
41+ openResponseCount++
42+ return object : HttpResponse {
43+ private var isClosed = false
44+
45+ override fun statusCode (): Int = response.statusCode()
46+
47+ override fun headers (): Headers = response.headers()
48+
49+ override fun body (): InputStream = response.body()
50+
51+ override fun close () {
52+ response.close()
53+ if (isClosed) {
54+ return
55+ }
56+ openResponseCount--
57+ isClosed = true
58+ }
59+ }
60+ }
61+ }
2162 resetAllScenarios()
2263 }
2364
@@ -35,6 +76,7 @@ internal class RetryingHttpClientTest {
3576
3677 assertThat(response.statusCode()).isEqualTo(200 )
3778 verify(1 , postRequestedFor(urlPathEqualTo(" /something" )))
79+ assertNoResponseLeaks()
3880 }
3981
4082 @ParameterizedTest
@@ -60,6 +102,7 @@ internal class RetryingHttpClientTest {
60102
61103 assertThat(response.statusCode()).isEqualTo(200 )
62104 verify(1 , postRequestedFor(urlPathEqualTo(" /something" )))
105+ assertNoResponseLeaks()
63106 }
64107
65108 @ParameterizedTest
@@ -116,6 +159,7 @@ internal class RetryingHttpClientTest {
116159 postRequestedFor(urlPathEqualTo(" /something" ))
117160 .withHeader(" x-stainless-retry-count" , equalTo(" 2" ))
118161 )
162+ assertNoResponseLeaks()
119163 }
120164
121165 @ParameterizedTest
@@ -156,6 +200,7 @@ internal class RetryingHttpClientTest {
156200 postRequestedFor(urlPathEqualTo(" /something" ))
157201 .withHeader(" x-stainless-retry-count" , equalTo(" 42" ))
158202 )
203+ assertNoResponseLeaks()
159204 }
160205
161206 @ParameterizedTest
@@ -186,8 +231,13 @@ internal class RetryingHttpClientTest {
186231
187232 assertThat(response.statusCode()).isEqualTo(200 )
188233 verify(2 , postRequestedFor(urlPathEqualTo(" /something" )))
234+ assertNoResponseLeaks()
189235 }
190236
191237 private fun HttpClient.execute (request : HttpRequest , async : Boolean ): HttpResponse =
192238 if (async) executeAsync(request).get() else execute(request)
239+
240+ // When retrying, all failed responses should be closed. Only the final returned response should
241+ // be open.
242+ private fun assertNoResponseLeaks () = assertThat(openResponseCount).isEqualTo(1 )
193243}
0 commit comments