Skip to content

Commit 2cf89fb

Browse files
committed
feat: add support for custom headers in Azure OpenAI requests
1 parent f2ce000 commit 2cf89fb

File tree

1 file changed

+38
-3
lines changed

1 file changed

+38
-3
lines changed

arc-azure-client/src/main/kotlin/AzureAIClient.kt

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import com.azure.ai.openai.models.ChatRequestUserMessage
1919
import com.azure.ai.openai.models.EmbeddingsOptions
2020
import com.azure.ai.openai.models.ReasoningEffortValue
2121
import com.azure.core.exception.ClientAuthenticationException
22+
import com.azure.core.http.HttpHeaderName
23+
import com.azure.core.http.rest.RequestOptions
2224
import com.azure.core.util.BinaryData.fromObject
2325
import com.azure.core.util.BinaryData.fromString
2426
import kotlinx.coroutines.reactive.awaitFirst
@@ -42,7 +44,6 @@ import org.eclipse.lmos.arc.agents.llm.ChatCompletionSettings
4244
import org.eclipse.lmos.arc.agents.llm.LLMStartedEvent
4345
import org.eclipse.lmos.arc.agents.llm.OutputFormat.JSON
4446
import org.eclipse.lmos.arc.agents.llm.OutputSchema
45-
import org.eclipse.lmos.arc.agents.llm.ReasoningEffort
4647
import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.HIGH
4748
import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.LOW
4849
import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.MEDIUM
@@ -63,6 +64,7 @@ import org.slf4j.LoggerFactory
6364
import sh.ondr.koja.Schema
6465
import sh.ondr.koja.toJsonElement
6566
import sh.ondr.koja.toSchema
67+
import java.util.concurrent.atomic.AtomicReference
6668
import kotlin.contracts.ExperimentalContracts
6769
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
6870
import kotlin.contracts.contract
@@ -213,7 +215,13 @@ class AzureAIClient(
213215
val result: Result<ChatCompletions, ArcException>
214216
val duration = measureTime {
215217
result = result<ChatCompletions, ArcException> {
216-
client.getChatCompletions(deploymentOrModelName, chatCompletionsOptions).awaitFirst()
218+
val requestOptions = getCustomRequestOptions(chatCompletionsOptions)
219+
if (requestOptions != null) {
220+
client.getChatCompletionsWithResponse(deploymentOrModelName, chatCompletionsOptions, requestOptions)
221+
.awaitFirst().getValue()
222+
} else {
223+
client.getChatCompletions(deploymentOrModelName, chatCompletionsOptions).awaitFirst()
224+
}
217225
}.mapFailure {
218226
log.error("Calling Azure OpenAI failed!", it)
219227
mapOpenAIException(it)
@@ -222,6 +230,17 @@ class AzureAIClient(
222230
return result to duration
223231
}
224232

233+
private fun getCustomRequestOptions(chatCompletionsOptions: ChatCompletionsOptions): RequestOptions? {
234+
val customHeaders =
235+
AzureOpenAICustomHeaders.getCustomHeaders(chatCompletionsOptions).takeIf { it.isNotEmpty() } ?: return null
236+
log.debug("Adding custom headers: $customHeaders")
237+
val requestOptions = RequestOptions()
238+
customHeaders.forEach { key, value ->
239+
requestOptions.addHeader(HttpHeaderName.fromString(key), value)
240+
}
241+
return requestOptions
242+
}
243+
225244
@OptIn(InternalSerializationApi::class)
226245
private fun OutputSchema.toSchema(): String {
227246
return (type.serializer().descriptor.toSchema() as Schema.ObjectSchema).copy(
@@ -258,7 +277,8 @@ class AzureAIClient(
258277
}
259278
}
260279
settings?.maxTokens?.let { maxTokens = it }
261-
settings?.format?.takeIf { JSON == it }?.let { responseFormat = ChatCompletionsJsonResponseFormat() }
280+
settings?.format?.takeIf { JSON == it && responseFormat == null }
281+
?.let { responseFormat = ChatCompletionsJsonResponseFormat() }
262282
if (openAIFunctions != null) tools = openAIFunctions
263283
}
264284

@@ -312,3 +332,18 @@ fun String.extractUseCaseId(): String? {
312332
null
313333
}
314334
}
335+
336+
/**
337+
* Provides a way to set custom headers for Azure OpenAI requests.
338+
*/
339+
object AzureOpenAICustomHeaders {
340+
341+
private val customHeaderProvider = AtomicReference<(ChatCompletionsOptions) -> Map<String, String>> { emptyMap() }
342+
343+
fun setCustomHeaderProvider(provider: (ChatCompletionsOptions) -> Map<String, String>) {
344+
customHeaderProvider.set(provider)
345+
}
346+
347+
fun getCustomHeaders(chatCompletionsOptions: ChatCompletionsOptions): Map<String, String> =
348+
customHeaderProvider.get().invoke(chatCompletionsOptions)
349+
}

0 commit comments

Comments
 (0)