Skip to content

Commit d9f6033

Browse files
RobertCraigiestainless-app[bot]
authored andcommitted
fix(azure): add missing azure changes
1 parent 05b5395 commit d9f6033

File tree

13 files changed

+294
-13
lines changed

13 files changed

+294
-13
lines changed

openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
package com.openai.client.okhttp
44

55
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.openai.azure.AzureOpenAIServiceVersion
67
import com.openai.client.OpenAIClient
78
import com.openai.client.OpenAIClientImpl
89
import com.openai.core.ClientOptions
910
import com.openai.core.http.Headers
1011
import com.openai.core.http.QueryParams
12+
import com.openai.credential.Credential
1113
import java.net.Proxy
1214
import java.time.Clock
1315
import java.time.Duration
@@ -130,6 +132,12 @@ class OpenAIOkHttpClient private constructor() {
130132

131133
fun apiKey(apiKey: String) = apply { clientOptions.apiKey(apiKey) }
132134

135+
fun credential(credential: Credential) = apply { clientOptions.credential(credential) }
136+
137+
fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply {
138+
clientOptions.azureServiceVersion(azureServiceVersion)
139+
}
140+
133141
fun organization(organization: String?) = apply { clientOptions.organization(organization) }
134142

135143
fun project(project: String?) = apply { clientOptions.project(project) }

openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
package com.openai.client.okhttp
44

55
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.openai.azure.AzureOpenAIServiceVersion
67
import com.openai.client.OpenAIClientAsync
78
import com.openai.client.OpenAIClientAsyncImpl
89
import com.openai.core.ClientOptions
910
import com.openai.core.http.Headers
1011
import com.openai.core.http.QueryParams
12+
import com.openai.credential.Credential
1113
import java.net.Proxy
1214
import java.time.Clock
1315
import java.time.Duration
@@ -130,6 +132,12 @@ class OpenAIOkHttpClientAsync private constructor() {
130132

131133
fun apiKey(apiKey: String) = apply { clientOptions.apiKey(apiKey) }
132134

135+
fun credential(credential: Credential) = apply { clientOptions.credential(credential) }
136+
137+
fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply {
138+
clientOptions.azureServiceVersion(azureServiceVersion)
139+
}
140+
133141
fun organization(organization: String?) = apply { clientOptions.organization(organization) }
134142

135143
fun project(project: String?) = apply { clientOptions.project(project) }
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.openai.azure
2+
3+
import java.util.concurrent.ConcurrentHashMap
4+
5+
class AzureOpenAIServiceVersion private constructor(@get:JvmName("value") val value: String) {
6+
7+
companion object {
8+
private val values: ConcurrentHashMap<String, AzureOpenAIServiceVersion> =
9+
ConcurrentHashMap()
10+
11+
@JvmStatic
12+
fun fromString(version: String): AzureOpenAIServiceVersion =
13+
values.computeIfAbsent(version) { AzureOpenAIServiceVersion(version) }
14+
15+
@JvmStatic val V2022_12_01 = fromString("2022-12-01")
16+
@JvmStatic val V2023_05_15 = fromString("2023-05-15")
17+
@JvmStatic val V2024_02_01 = fromString("2024-02-01")
18+
@JvmStatic val V2024_06_01 = fromString("2024-06-01")
19+
@JvmStatic val V2023_06_01_PREVIEW = fromString("2023-06-01-preview")
20+
@JvmStatic val V2023_07_01_PREVIEW = fromString("2023-07-01-preview")
21+
@JvmStatic val V2024_02_15_PREVIEW = fromString("2024-02-15-preview")
22+
@JvmStatic val V2024_03_01_PREVIEW = fromString("2024-03-01-preview")
23+
@JvmStatic val V2024_04_01_PREVIEW = fromString("2024-04-01-preview")
24+
@JvmStatic val V2024_05_01_PREVIEW = fromString("2024-05-01-preview")
25+
@JvmStatic val V2024_07_01_PREVIEW = fromString("2024-07-01-preview")
26+
@JvmStatic val V2024_08_01_PREVIEW = fromString("2024-08-01-preview")
27+
@JvmStatic val V2024_09_01_PREVIEW = fromString("2024-09-01-preview")
28+
}
29+
30+
override fun equals(other: Any?): Boolean =
31+
this === other || (other is AzureOpenAIServiceVersion && value == other.value)
32+
33+
override fun hashCode(): Int = value.hashCode()
34+
35+
override fun toString(): String = "AzureOpenAIServiceVersion{value=$value}"
36+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.openai.azure.credential
2+
3+
import com.openai.credential.Credential
4+
5+
/** A credential that provides an Azure API key. */
6+
class AzureApiKeyCredential private constructor(private var apiKey: String) : Credential {
7+
8+
init {
9+
validateApiKey(apiKey)
10+
}
11+
12+
companion object {
13+
@JvmStatic fun create(apiKey: String): Credential = AzureApiKeyCredential(apiKey)
14+
15+
private fun validateApiKey(apiKey: String) {
16+
require(apiKey.isNotEmpty()) { "Azure API key cannot be empty." }
17+
}
18+
}
19+
20+
fun apiKey(): String = apiKey
21+
22+
fun update(apiKey: String) = apply {
23+
validateApiKey(apiKey)
24+
this.apiKey = apiKey
25+
}
26+
}

openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
package com.openai.core
44

55
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.openai.azure.AzureOpenAIServiceVersion
7+
import com.openai.azure.AzureOpenAIServiceVersion.Companion.V2024_06_01
8+
import com.openai.azure.credential.AzureApiKeyCredential
69
import com.openai.core.http.Headers
710
import com.openai.core.http.HttpClient
811
import com.openai.core.http.PhantomReachableClosingHttpClient
912
import com.openai.core.http.QueryParams
1013
import com.openai.core.http.RetryingHttpClient
14+
import com.openai.credential.BearerTokenCredential
15+
import com.openai.credential.Credential
1116
import java.time.Clock
1217
import java.util.concurrent.Executor
1318
import java.util.concurrent.Executors
@@ -26,7 +31,7 @@ private constructor(
2631
@get:JvmName("queryParams") val queryParams: QueryParams,
2732
@get:JvmName("responseValidation") val responseValidation: Boolean,
2833
@get:JvmName("maxRetries") val maxRetries: Int,
29-
@get:JvmName("apiKey") val apiKey: String,
34+
@get:JvmName("credential") val credential: Credential,
3035
@get:JvmName("organization") val organization: String?,
3136
@get:JvmName("project") val project: String?,
3237
) {
@@ -53,7 +58,8 @@ private constructor(
5358
private var queryParams: QueryParams.Builder = QueryParams.builder()
5459
private var responseValidation: Boolean = false
5560
private var maxRetries: Int = 2
56-
private var apiKey: String? = null
61+
private var credential: Credential? = null
62+
private var azureServiceVersion: AzureOpenAIServiceVersion? = null
5763
private var organization: String? = null
5864
private var project: String? = null
5965

@@ -68,7 +74,7 @@ private constructor(
6874
queryParams = clientOptions.queryParams.toBuilder()
6975
responseValidation = clientOptions.responseValidation
7076
maxRetries = clientOptions.maxRetries
71-
apiKey = clientOptions.apiKey
77+
credential = clientOptions.credential
7278
organization = clientOptions.organization
7379
project = clientOptions.project
7480
}
@@ -171,21 +177,56 @@ private constructor(
171177

172178
fun maxRetries(maxRetries: Int) = apply { this.maxRetries = maxRetries }
173179

174-
fun apiKey(apiKey: String) = apply { this.apiKey = apiKey }
180+
fun apiKey(apiKey: String) = apply {
181+
this.credential = BearerTokenCredential.create(apiKey)
182+
}
183+
184+
fun credential(credential: Credential) = apply { this.credential = credential }
185+
186+
fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply {
187+
this.azureServiceVersion = azureServiceVersion
188+
}
175189

176190
fun organization(organization: String?) = apply { this.organization = organization }
177191

178192
fun project(project: String?) = apply { this.project = project }
179193

180194
fun fromEnv() = apply {
181-
System.getenv("OPENAI_API_KEY")?.let { apiKey(it) }
182-
System.getenv("OPENAI_ORG_ID")?.let { organization(it) }
183-
System.getenv("OPENAI_PROJECT_ID")?.let { project(it) }
195+
val openAIKey = System.getenv("OPENAI_API_KEY")
196+
val openAIOrgId = System.getenv("OPENAI_ORG_ID")
197+
val openAIProjectId = System.getenv("OPENAI_PROJECT_ID")
198+
val azureOpenAIKey = System.getenv("AZURE_OPENAI_KEY")
199+
val azureEndpoint = System.getenv("AZURE_OPENAI_ENDPOINT")
200+
201+
when {
202+
!openAIKey.isNullOrEmpty() && !azureOpenAIKey.isNullOrEmpty() -> {
203+
throw IllegalArgumentException(
204+
"Both OpenAI and Azure OpenAI API keys, `OPENAI_API_KEY` and `AZURE_OPENAI_KEY`, are set. Please specify only one"
205+
)
206+
}
207+
!openAIKey.isNullOrEmpty() -> {
208+
credential(BearerTokenCredential.create(openAIKey))
209+
organization(openAIOrgId)
210+
project(openAIProjectId)
211+
}
212+
!azureOpenAIKey.isNullOrEmpty() -> {
213+
credential(AzureApiKeyCredential.create(azureOpenAIKey))
214+
baseUrl(azureEndpoint)
215+
}
216+
!azureEndpoint.isNullOrEmpty() -> {
217+
// Both 'openAIKey' and 'azureOpenAIKey' are not set.
218+
// Only 'azureEndpoint' is set here, and user still needs to call method
219+
// '.credential(BearerTokenCredential(Supplier<String>))'
220+
// to get the token through the supplier, which requires Azure Entra ID as a
221+
// dependency.
222+
baseUrl(azureEndpoint)
223+
}
224+
}
184225
}
185226

186227
fun build(): ClientOptions {
187228
checkNotNull(httpClient) { "`httpClient` is required but was not set" }
188-
checkNotNull(apiKey) { "`apiKey` is required but was not set" }
229+
checkNotNull(credential) { "`credential` is required but was not set" }
189230

190231
val headers = Headers.builder()
191232
val queryParams = QueryParams.builder()
@@ -198,11 +239,26 @@ private constructor(
198239
headers.put("X-Stainless-Runtime-Version", getJavaVersion())
199240
organization?.let { headers.put("OpenAI-Organization", it) }
200241
project?.let { headers.put("OpenAI-Project", it) }
201-
apiKey?.let {
202-
if (!it.isEmpty()) {
203-
headers.put("Authorization", "Bearer $it")
242+
243+
when (val currentCredential = credential) {
244+
is AzureApiKeyCredential -> {
245+
headers.put("api-key", currentCredential.apiKey())
246+
}
247+
is BearerTokenCredential -> {
248+
headers.put("Authorization", "Bearer ${currentCredential.token()}")
249+
}
250+
else -> {
251+
throw IllegalArgumentException("Invalid credential type")
204252
}
205253
}
254+
255+
if (isAzureEndpoint(baseUrl)) {
256+
// Default Azure OpenAI version is used if Azure user doesn't
257+
// specific a service API version in 'queryParams'.
258+
// We can update the default value every major announcement if needed.
259+
replaceQueryParams("api-version", (azureServiceVersion ?: V2024_06_01).value)
260+
}
261+
206262
headers.replaceAll(this.headers.build())
207263
queryParams.replaceAll(this.queryParams.build())
208264

@@ -237,7 +293,7 @@ private constructor(
237293
queryParams.build(),
238294
responseValidation,
239295
maxRetries,
240-
apiKey!!,
296+
credential!!,
241297
organization,
242298
project,
243299
)

openai-java-core/src/main/kotlin/com/openai/core/Utils.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,11 @@ internal fun <K : Comparable<K>, V> SortedMap<K, V>.toImmutable(): SortedMap<K,
2323
if (isEmpty()) Collections.emptySortedMap()
2424
else Collections.unmodifiableSortedMap(toSortedMap(comparator()))
2525

26+
@JvmSynthetic
27+
internal fun isAzureEndpoint(baseUrl: String): Boolean {
28+
// Azure Endpoint should be in the format of `https://<region>.openai.azure.com`.
29+
// Or `https://<region>.azure-api.net` for Azure OpenAI Management URL.
30+
return baseUrl.endsWith(".openai.azure.com", true) || baseUrl.endsWith(".azure-api.net", true)
31+
}
32+
2633
internal interface Enum
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.openai.credential
2+
3+
import java.util.function.Supplier
4+
5+
/**
6+
* <p> A credential that provides a bearer token. </p>
7+
*
8+
* <p>
9+
* If you are using the OpenAI API, you need to provide a bearer token for authentication. All API
10+
* requests should include your API key in an Authorization HTTP header as follows: "Authorization:
11+
* Bearer OPENAI_API_KEY" </p>
12+
*
13+
* <p> Two ways to provide the token: </p>
14+
* <ol>
15+
* <l1> 1. Provide the token directly, 'BearerTokenCredential.create(String)'. The method
16+
* 'ClientOptions.apiKey(String)' is a wrapper for this.</li> <l1> 2. Provide a supplier that
17+
* provides the token, 'BearerTokenCredential.create(Supplier<String>)'.</li>
18+
* </ol>
19+
*
20+
* @param tokenSupplier a supplier that provides the token.
21+
* @see <a href="https://platform.openai.com/docs/api-reference/authentication">OpenAI
22+
* Authentication</a>
23+
*/
24+
class BearerTokenCredential private constructor(private val tokenSupplier: Supplier<String>) :
25+
Credential {
26+
27+
companion object {
28+
@JvmStatic fun create(token: String): Credential = BearerTokenCredential { token }
29+
30+
@JvmStatic
31+
fun create(tokenSupplier: Supplier<String>): Credential =
32+
BearerTokenCredential(tokenSupplier)
33+
}
34+
35+
fun token(): String = tokenSupplier.get()
36+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package com.openai.credential
2+
3+
/** An interface that represents a credential. */
4+
interface Credential

openai-java-core/src/main/kotlin/com/openai/services/async/chat/CompletionServiceAsyncImpl.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ import com.openai.core.http.HttpRequest
1717
import com.openai.core.http.HttpResponse.Handler
1818
import com.openai.core.http.StreamResponse
1919
import com.openai.core.http.toAsync
20+
import com.openai.core.isAzureEndpoint
2021
import com.openai.core.json
22+
import com.openai.credential.BearerTokenCredential
2123
import com.openai.errors.OpenAIError
2224
import com.openai.models.ChatCompletion
2325
import com.openai.models.ChatCompletionChunk
@@ -47,10 +49,23 @@ constructor(
4749
val request =
4850
HttpRequest.builder()
4951
.method(HttpMethod.POST)
52+
.apply {
53+
if (isAzureEndpoint(clientOptions.baseUrl)) {
54+
addPathSegments("openai", "deployments", params.model().toString())
55+
}
56+
}
5057
.addPathSegments("chat", "completions")
5158
.putAllQueryParams(clientOptions.queryParams)
5259
.replaceAllQueryParams(params.getQueryParams())
5360
.putAllHeaders(clientOptions.headers)
61+
.apply {
62+
if (
63+
isAzureEndpoint(clientOptions.baseUrl) &&
64+
clientOptions.credential is BearerTokenCredential
65+
) {
66+
putHeader("Authorization", "Bearer ${clientOptions.credential.token()}")
67+
}
68+
}
5469
.replaceAllHeaders(params.getHeaders())
5570
.body(json(clientOptions.jsonMapper, params.getBody()))
5671
.build()

openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/CompletionServiceImpl.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ import com.openai.core.http.HttpMethod
1515
import com.openai.core.http.HttpRequest
1616
import com.openai.core.http.HttpResponse.Handler
1717
import com.openai.core.http.StreamResponse
18+
import com.openai.core.isAzureEndpoint
1819
import com.openai.core.json
20+
import com.openai.credential.BearerTokenCredential
1921
import com.openai.errors.OpenAIError
2022
import com.openai.models.ChatCompletion
2123
import com.openai.models.ChatCompletionChunk
@@ -44,10 +46,23 @@ constructor(
4446
val request =
4547
HttpRequest.builder()
4648
.method(HttpMethod.POST)
49+
.apply {
50+
if (isAzureEndpoint(clientOptions.baseUrl)) {
51+
addPathSegments("openai", "deployments", params.model().toString())
52+
}
53+
}
4754
.addPathSegments("chat", "completions")
4855
.putAllQueryParams(clientOptions.queryParams)
4956
.replaceAllQueryParams(params.getQueryParams())
5057
.putAllHeaders(clientOptions.headers)
58+
.apply {
59+
if (
60+
isAzureEndpoint(clientOptions.baseUrl) &&
61+
clientOptions.credential is BearerTokenCredential
62+
) {
63+
putHeader("Authorization", "Bearer ${clientOptions.credential.token()}")
64+
}
65+
}
5166
.replaceAllHeaders(params.getHeaders())
5267
.body(json(clientOptions.jsonMapper, params.getBody()))
5368
.build()

0 commit comments

Comments
 (0)