=
+ apply {
+ asyncStreamResponse.subscribe(handler, executor)
+ }
+
+ override fun close() = asyncStreamResponse.close()
+}
diff --git a/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt b/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt
deleted file mode 100644
index 6da402a95..000000000
--- a/openai-java-core/src/main/kotlin/com/openai/credential/BearerTokenCredential.kt
+++ /dev/null
@@ -1,36 +0,0 @@
-package com.openai.credential
-
-import java.util.function.Supplier
-
-/**
- * A credential that provides a bearer token.
- *
- *
- * If you are using the OpenAI API, you need to provide a bearer token for authentication. All API
- * requests should include your API key in an Authorization HTTP header as follows: "Authorization:
- * Bearer OPENAI_API_KEY"
- *
- * Two ways to provide the token:
- *
- * 1. Provide the token directly, 'BearerTokenCredential.create(String)'. The method
- * 'ClientOptions.apiKey(String)' is a wrapper for this. 2. Provide a supplier that
- * provides the token, 'BearerTokenCredential.create(Supplier)'.
- *
- *
- * @param tokenSupplier a supplier that provides the token.
- * @see <a href="https://platform.openai.com/docs/api-reference/authentication">OpenAI
- * Authentication</a>
- */
-class BearerTokenCredential private constructor(private val tokenSupplier: Supplier<String>) :
- Credential {
-
- companion object {
- @JvmStatic fun create(token: String): Credential = BearerTokenCredential { token }
-
- @JvmStatic
- fun create(tokenSupplier: Supplier<String>): Credential =
- BearerTokenCredential(tokenSupplier)
- }
-
- fun token(): String = tokenSupplier.get()
-}
diff --git a/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt b/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt
deleted file mode 100644
index f43ab84c4..000000000
--- a/openai-java-core/src/main/kotlin/com/openai/credential/Credential.kt
+++ /dev/null
@@ -1,4 +0,0 @@
-package com.openai.credential
-
-/** An interface that represents a credential. */
-interface Credential
diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsync.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsync.kt
index b3b4ee5a0..ce272a947 100644
--- a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsync.kt
+++ b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsync.kt
@@ -5,6 +5,7 @@
package com.openai.services.async
import com.openai.core.RequestOptions
+import com.openai.core.http.AsyncStreamResponse
import com.openai.models.Completion
import com.openai.models.CompletionCreateParams
import java.util.concurrent.CompletableFuture
@@ -17,4 +18,11 @@ interface CompletionServiceAsync {
params: CompletionCreateParams,
requestOptions: RequestOptions = RequestOptions.none()
): CompletableFuture<Completion>
+
+ /** Creates a completion for the provided prompt and parameters. */
+ @JvmOverloads
+ fun createStreaming(
+ params: CompletionCreateParams,
+ requestOptions: RequestOptions = RequestOptions.none()
+ ): AsyncStreamResponse<Completion>
}
diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt
index 4c5ed7df3..d9a137955 100644
--- a/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt
+++ b/openai-java-core/src/main/kotlin/com/openai/services/async/CompletionServiceAsyncImpl.kt
@@ -3,13 +3,20 @@
package com.openai.services.async
import com.openai.core.ClientOptions
+import com.openai.core.JsonValue
import com.openai.core.RequestOptions
import com.openai.core.handlers.errorHandler
import com.openai.core.handlers.jsonHandler
+import com.openai.core.handlers.map
+import com.openai.core.handlers.mapJson
+import com.openai.core.handlers.sseHandler
import com.openai.core.handlers.withErrorHandler
+import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.HttpMethod
import com.openai.core.http.HttpRequest
import com.openai.core.http.HttpResponse.Handler
+import com.openai.core.http.StreamResponse
+import com.openai.core.http.toAsync
import com.openai.core.json
import com.openai.errors.OpenAIError
import com.openai.models.Completion
@@ -52,4 +59,47 @@ constructor(
}
}
}
+
+ private val createStreamingHandler: Handler<StreamResponse<Completion>> =
+ sseHandler(clientOptions.jsonMapper).mapJson<Completion>().withErrorHandler(errorHandler)
+
+ /** Creates a completion for the provided prompt and parameters. */
+ override fun createStreaming(
+ params: CompletionCreateParams,
+ requestOptions: RequestOptions
+ ): AsyncStreamResponse<Completion> {
+ val request =
+ HttpRequest.builder()
+ .method(HttpMethod.POST)
+ .addPathSegments("completions")
+ .putAllQueryParams(clientOptions.queryParams)
+ .replaceAllQueryParams(params.getQueryParams())
+ .putAllHeaders(clientOptions.headers)
+ .replaceAllHeaders(params.getHeaders())
+ .body(
+ json(
+ clientOptions.jsonMapper,
+ params
+ .getBody()
+ .toBuilder()
+ .putAdditionalProperty("stream", JsonValue.from(true))
+ .build()
+ )
+ )
+ .build()
+ return clientOptions.httpClient
+ .executeAsync(request, requestOptions)
+ .thenApply { response ->
+ response
+ .let { createStreamingHandler.handle(it) }
+ .let { streamResponse ->
+ if (requestOptions.responseValidation ?: clientOptions.responseValidation) {
+ streamResponse.map { it.validate() }
+ } else {
+ streamResponse
+ }
+ }
+ }
+ .toAsync(clientOptions.streamHandlerExecutor)
+ }
}
diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsync.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsync.kt
index d0214cc6a..ee22a7a6a 100644
--- a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsync.kt
+++ b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsync.kt
@@ -5,7 +5,9 @@
package com.openai.services.async.chat
import com.openai.core.RequestOptions
+import com.openai.core.http.AsyncStreamResponse
import com.openai.models.ChatCompletion
+import com.openai.models.ChatCompletionChunk
import com.openai.models.ChatCompletionCreateParams
import java.util.concurrent.CompletableFuture
@@ -22,4 +24,16 @@ interface CompletionServiceAsync {
params: ChatCompletionCreateParams,
requestOptions: RequestOptions = RequestOptions.none()
): CompletableFuture<ChatCompletion>
+
+ /**
+ * Creates a model response for the given chat conversation. Learn more in the
+ * [text generation](https://platform.openai.com/docs/guides/text-generation),
+ * [vision](https://platform.openai.com/docs/guides/vision), and
+ * [audio](https://platform.openai.com/docs/guides/audio) guides.
+ */
+ @JvmOverloads
+ fun createStreaming(
+ params: ChatCompletionCreateParams,
+ requestOptions: RequestOptions = RequestOptions.none()
+ ): AsyncStreamResponse<ChatCompletionChunk>
}
diff --git a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt
index 17ce15ddf..f52b53324 100644
--- a/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt
+++ b/openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt
@@ -3,18 +3,24 @@
package com.openai.services.async.chat
import com.openai.core.ClientOptions
+import com.openai.core.JsonValue
import com.openai.core.RequestOptions
import com.openai.core.handlers.errorHandler
import com.openai.core.handlers.jsonHandler
+import com.openai.core.handlers.map
+import com.openai.core.handlers.mapJson
+import com.openai.core.handlers.sseHandler
import com.openai.core.handlers.withErrorHandler
+import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.HttpMethod
import com.openai.core.http.HttpRequest
import com.openai.core.http.HttpResponse.Handler
-import com.openai.core.isAzureEndpoint
+import com.openai.core.http.StreamResponse
+import com.openai.core.http.toAsync
import com.openai.core.json
-import com.openai.credential.BearerTokenCredential
import com.openai.errors.OpenAIError
import com.openai.models.ChatCompletion
+import com.openai.models.ChatCompletionChunk
import com.openai.models.ChatCompletionCreateParams
import java.util.concurrent.CompletableFuture
@@ -41,23 +47,10 @@ constructor(
val request =
HttpRequest.builder()
.method(HttpMethod.POST)
- .apply {
- if (isAzureEndpoint(clientOptions.baseUrl)) {
- addPathSegments("openai", "deployments", params.model().toString())
- }
- }
.addPathSegments("chat", "completions")
.putAllQueryParams(clientOptions.queryParams)
.replaceAllQueryParams(params.getQueryParams())
.putAllHeaders(clientOptions.headers)
- .apply {
- if (
- isAzureEndpoint(clientOptions.baseUrl) &&
- clientOptions.credential is BearerTokenCredential
- ) {
- putHeader("Authorization", "Bearer ${clientOptions.credential.token()}")
- }
- }
.replaceAllHeaders(params.getHeaders())
.body(json(clientOptions.jsonMapper, params.getBody()))
.build()
@@ -72,4 +65,54 @@ constructor(
}
}
}
+
+ private val createStreamingHandler: Handler<StreamResponse<ChatCompletionChunk>> =
+ sseHandler(clientOptions.jsonMapper)
+ .mapJson<ChatCompletionChunk>()
+ .withErrorHandler(errorHandler)
+
+ /**
+ * Creates a model response for the given chat conversation. Learn more in the
+ * [text generation](https://platform.openai.com/docs/guides/text-generation),
+ * [vision](https://platform.openai.com/docs/guides/vision), and
+ * [audio](https://platform.openai.com/docs/guides/audio) guides.
+ */
+ override fun createStreaming(
+ params: ChatCompletionCreateParams,
+ requestOptions: RequestOptions
+ ): AsyncStreamResponse<ChatCompletionChunk> {
+ val request =
+ HttpRequest.builder()
+ .method(HttpMethod.POST)
+ .addPathSegments("chat", "completions")
+ .putAllQueryParams(clientOptions.queryParams)
+ .replaceAllQueryParams(params.getQueryParams())
+ .putAllHeaders(clientOptions.headers)
+ .replaceAllHeaders(params.getHeaders())
+ .body(
+ json(
+ clientOptions.jsonMapper,
+ params
+ .getBody()
+ .toBuilder()
+ .putAdditionalProperty("stream", JsonValue.from(true))
+ .build()
+ )
+ )
+ .build()
+ return clientOptions.httpClient
+ .executeAsync(request, requestOptions)
+ .thenApply { response ->
+ response
+ .let { createStreamingHandler.handle(it) }
+ .let { streamResponse ->
+ if (requestOptions.responseValidation ?: clientOptions.responseValidation) {
+ streamResponse.map { it.validate() }
+ } else {
+ streamResponse
+ }
+ }
+ }
+ .toAsync(clientOptions.streamHandlerExecutor)
+ }
}
diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt
index f6ea7c91f..4052936fb 100644
--- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt
+++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt
@@ -15,9 +15,7 @@ import com.openai.core.http.HttpMethod
import com.openai.core.http.HttpRequest
import com.openai.core.http.HttpResponse.Handler
import com.openai.core.http.StreamResponse
-import com.openai.core.isAzureEndpoint
import com.openai.core.json
-import com.openai.credential.BearerTokenCredential
import com.openai.errors.OpenAIError
import com.openai.models.ChatCompletion
import com.openai.models.ChatCompletionChunk
@@ -46,23 +44,10 @@ constructor(
val request =
HttpRequest.builder()
.method(HttpMethod.POST)
- .apply {
- if (isAzureEndpoint(clientOptions.baseUrl)) {
- addPathSegments("openai", "deployments", params.model().toString())
- }
- }
.addPathSegments("chat", "completions")
.putAllQueryParams(clientOptions.queryParams)
.replaceAllQueryParams(params.getQueryParams())
.putAllHeaders(clientOptions.headers)
- .apply {
- if (
- isAzureEndpoint(clientOptions.baseUrl) &&
- clientOptions.credential is BearerTokenCredential
- ) {
- putHeader("Authorization", "Bearer ${clientOptions.credential.token()}")
- }
- }
.replaceAllHeaders(params.getHeaders())
.body(json(clientOptions.jsonMapper, params.getBody()))
.build()
diff --git a/openai-java-core/src/test/kotlin/com/openai/core/http/AsyncStreamResponseTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/http/AsyncStreamResponseTest.kt
new file mode 100644
index 000000000..257951a2e
--- /dev/null
+++ b/openai-java-core/src/test/kotlin/com/openai/core/http/AsyncStreamResponseTest.kt
@@ -0,0 +1,165 @@
+package com.openai.core.http
+
+import java.util.*
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.Executor
+import java.util.stream.Stream
+import kotlin.streams.asStream
+import org.assertj.core.api.Assertions.assertThat
+import org.assertj.core.api.Assertions.catchThrowable
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.assertDoesNotThrow
+import org.junit.jupiter.api.extension.ExtendWith
+import org.mockito.junit.jupiter.MockitoExtension
+import org.mockito.kotlin.*
+
+@ExtendWith(MockitoExtension::class)
+internal class AsyncStreamResponseTest {
+
+ companion object {
+ private val ERROR = RuntimeException("ERROR!")
+ }
+
+ private val streamResponse =
+ spy<StreamResponse<String>> {
+ doReturn(Stream.of("chunk1", "chunk2", "chunk3")).whenever(it).stream()
+ }
+ private val erroringStreamResponse =
+ spy<StreamResponse<String>> {
+ doReturn(
+ sequence {
+ yield("chunk1")
+ yield("chunk2")
+ throw ERROR
+ }
+ .asStream()
+ )
+ .whenever(it)
+ .stream()
+ }
+ private val executor =
+ spy<Executor> {
+ doAnswer { invocation -> invocation.getArgument<Runnable>(0).run() }
+ .whenever(it)
+ .execute(any())
+ }
+ private val handler = mock<AsyncStreamResponse.Handler<String>>()
+
+ @Test
+ fun subscribe_whenAlreadySubscribed_throws() {
+ val asyncStreamResponse = CompletableFuture<StreamResponse<String>>().toAsync(executor)
+ asyncStreamResponse.subscribe {}
+
+ val throwable = catchThrowable { asyncStreamResponse.subscribe {} }
+
+ assertThat(throwable).isInstanceOf(IllegalStateException::class.java)
+ assertThat(throwable).hasMessage("Cannot subscribe more than once")
+ verify(executor, never()).execute(any())
+ }
+
+ @Test
+ fun subscribe_whenClosed_throws() {
+ val asyncStreamResponse = CompletableFuture<StreamResponse<String>>().toAsync(executor)
+ asyncStreamResponse.close()
+
+ val throwable = catchThrowable { asyncStreamResponse.subscribe {} }
+
+ assertThat(throwable).isInstanceOf(IllegalStateException::class.java)
+ assertThat(throwable).hasMessage("Cannot subscribe after the response is closed")
+ verify(executor, never()).execute(any())
+ }
+
+ @Test
+ fun subscribe_whenFutureCompletesAfterClose_doesNothing() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.subscribe(handler)
+ asyncStreamResponse.close()
+
+ future.complete(streamResponse)
+
+ verify(handler, never()).onNext(any())
+ verify(handler, never()).onComplete(any())
+ verify(executor, times(1)).execute(any())
+ }
+
+ @Test
+ fun subscribe_whenFutureErrors_callsOnComplete() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.subscribe(handler)
+
+ future.completeExceptionally(ERROR)
+
+ verify(handler, never()).onNext(any())
+ verify(handler, times(1)).onComplete(Optional.of(ERROR))
+ verify(executor, times(1)).execute(any())
+ }
+
+ @Test
+ fun subscribe_whenFutureCompletes_runsHandler() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.subscribe(handler)
+
+ future.complete(streamResponse)
+
+ inOrder(handler, streamResponse) {
+ verify(handler, times(1)).onNext("chunk1")
+ verify(handler, times(1)).onNext("chunk2")
+ verify(handler, times(1)).onNext("chunk3")
+ verify(handler, times(1)).onComplete(Optional.empty())
+ verify(streamResponse, times(1)).close()
+ }
+ verify(executor, times(1)).execute(any())
+ }
+
+ @Test
+ fun subscribe_whenStreamErrors_callsOnCompleteEarly() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.subscribe(handler)
+
+ future.complete(erroringStreamResponse)
+
+ inOrder(handler, erroringStreamResponse) {
+ verify(handler, times(1)).onNext("chunk1")
+ verify(handler, times(1)).onNext("chunk2")
+ verify(handler, times(1)).onComplete(Optional.of(ERROR))
+ verify(erroringStreamResponse, times(1)).close()
+ }
+ verify(executor, times(1)).execute(any())
+ }
+
+ @Test
+ fun close_whenNotClosed_closesStreamResponse() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+
+ asyncStreamResponse.close()
+ future.complete(streamResponse)
+
+ verify(streamResponse, times(1)).close()
+ }
+
+ @Test
+ fun close_whenAlreadyClosed_doesNothing() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.close()
+ future.complete(streamResponse)
+
+ asyncStreamResponse.close()
+
+ verify(streamResponse, times(1)).close()
+ }
+
+ @Test
+ fun close_whenFutureErrors_doesNothing() {
+ val future = CompletableFuture<StreamResponse<String>>()
+ val asyncStreamResponse = future.toAsync(executor)
+ asyncStreamResponse.close()
+
+ assertDoesNotThrow { future.completeExceptionally(ERROR) }
+ }
+}
diff --git a/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt
deleted file mode 100644
index d53f5d8f9..000000000
--- a/openai-java-core/src/test/kotlin/com/openai/core/http/ClientOptionsTest.kt
+++ /dev/null
@@ -1,70 +0,0 @@
-package com.openai.core.http
-
-import com.openai.azure.credential.AzureApiKeyCredential
-import com.openai.client.okhttp.OkHttpClient
-import com.openai.core.ClientOptions
-import com.openai.credential.BearerTokenCredential
-import java.util.stream.Stream
-import kotlin.test.Test
-import org.assertj.core.api.Assertions.assertThat
-import org.assertj.core.api.Assertions.assertThatThrownBy
-import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.MethodSource
-
-internal class ClientOptionsTest {
-
- companion object {
- private const val FAKE_API_KEY = "test-api-key"
-
- @JvmStatic
- private fun createOkHttpClient(baseUrl: String): OkHttpClient {
- return OkHttpClient.builder().baseUrl(baseUrl).build()
- }
-
- @JvmStatic
- private fun provideBaseUrls(): Stream<String> {
- return Stream.of(
- "https://api.openai.com/v1",
- "https://example.openai.azure.com",
- "https://example.azure-api.net"
- )
- }
- }
-
- @ParameterizedTest
- @MethodSource("provideBaseUrls")
- fun clientOptionsWithoutBaseUrl(baseUrl: String) {
- // Arrange
- val apiKey = FAKE_API_KEY
-
- // Act
- val clientOptions =
- ClientOptions.builder()
- .httpClient(createOkHttpClient(baseUrl))
- .credential(BearerTokenCredential.create(apiKey))
- .build()
-
- // Assert
- assertThat(clientOptions.baseUrl).isEqualTo(ClientOptions.PRODUCTION_URL)
- }
-
- @ParameterizedTest
- @MethodSource("provideBaseUrls")
- fun throwExceptionWhenNullCredential(baseUrl: String) {
- // Act
- val clientOptionsBuilder =
- ClientOptions.builder().httpClient(createOkHttpClient(baseUrl)).baseUrl(baseUrl)
-
- // Assert
- assertThatThrownBy { clientOptionsBuilder.build() }
- .isInstanceOf(IllegalStateException::class.java)
- .hasMessage("`credential` is required but was not set")
- }
-
- @Test
- fun throwExceptionWhenEmptyCredential() {
- assertThatThrownBy { AzureApiKeyCredential.create("") }
- .isInstanceOf(IllegalArgumentException::class.java)
- .hasMessage("Azure API key cannot be empty.")
- }
-}
diff --git a/openai-java-core/src/test/kotlin/com/openai/services/ErrorHandlingTest.kt b/openai-java-core/src/test/kotlin/com/openai/services/ErrorHandlingTest.kt
index c9d34bb02..17dc53f01 100644
--- a/openai-java-core/src/test/kotlin/com/openai/services/ErrorHandlingTest.kt
+++ b/openai-java-core/src/test/kotlin/com/openai/services/ErrorHandlingTest.kt
@@ -31,7 +31,6 @@ import com.openai.models.*
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.InstanceOfAssertFactories
-import org.assertj.guava.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
diff --git a/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt b/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt
index d779133bd..5ed44e5f2 100644
--- a/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt
+++ b/openai-java-core/src/test/kotlin/com/openai/services/blocking/fineTuning/jobs/CheckpointServiceTest.kt
@@ -4,6 +4,7 @@ package com.openai.services.blocking.fineTuning.jobs
import com.openai.TestServerExtension
import com.openai.client.okhttp.OpenAIOkHttpClient
+import com.openai.models.*
import com.openai.models.FineTuningJobCheckpointListParams
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 58e9de020..3c5725fb3 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -4,4 +4,3 @@ include("openai-java")
include("openai-java-client-okhttp")
include("openai-java-core")
include("openai-java-example")
-include("openai-azure-java-example")