@@ -19,6 +19,8 @@ import com.azure.ai.openai.models.ChatRequestUserMessage
1919import com.azure.ai.openai.models.EmbeddingsOptions
2020import com.azure.ai.openai.models.ReasoningEffortValue
2121import com.azure.core.exception.ClientAuthenticationException
22+ import com.azure.core.http.HttpHeaderName
23+ import com.azure.core.http.rest.RequestOptions
2224import com.azure.core.util.BinaryData.fromObject
2325import com.azure.core.util.BinaryData.fromString
2426import kotlinx.coroutines.reactive.awaitFirst
@@ -42,7 +44,6 @@ import org.eclipse.lmos.arc.agents.llm.ChatCompletionSettings
4244import org.eclipse.lmos.arc.agents.llm.LLMStartedEvent
4345import org.eclipse.lmos.arc.agents.llm.OutputFormat.JSON
4446import org.eclipse.lmos.arc.agents.llm.OutputSchema
45- import org.eclipse.lmos.arc.agents.llm.ReasoningEffort
4647import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.HIGH
4748import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.LOW
4849import org.eclipse.lmos.arc.agents.llm.ReasoningEffort.MEDIUM
@@ -63,6 +64,7 @@ import org.slf4j.LoggerFactory
6364import sh.ondr.koja.Schema
6465import sh.ondr.koja.toJsonElement
6566import sh.ondr.koja.toSchema
67+ import java.util.concurrent.atomic.AtomicReference
6668import kotlin.contracts.ExperimentalContracts
6769import kotlin.contracts.InvocationKind.EXACTLY_ONCE
6870import kotlin.contracts.contract
@@ -213,7 +215,13 @@ class AzureAIClient(
213215 val result: Result <ChatCompletions , ArcException >
214216 val duration = measureTime {
215217 result = result<ChatCompletions , ArcException > {
216- client.getChatCompletions(deploymentOrModelName, chatCompletionsOptions).awaitFirst()
218+ val requestOptions = getCustomRequestOptions(chatCompletionsOptions)
219+ if (requestOptions != null ) {
220+ client.getChatCompletionsWithResponse(deploymentOrModelName, chatCompletionsOptions, requestOptions)
221+ .awaitFirst().getValue()
222+ } else {
223+ client.getChatCompletions(deploymentOrModelName, chatCompletionsOptions).awaitFirst()
224+ }
217225 }.mapFailure {
218226 log.error(" Calling Azure OpenAI failed!" , it)
219227 mapOpenAIException(it)
@@ -222,6 +230,17 @@ class AzureAIClient(
222230 return result to duration
223231 }
224232
233+ private fun getCustomRequestOptions (chatCompletionsOptions : ChatCompletionsOptions ): RequestOptions ? {
234+ val customHeaders =
235+ AzureOpenAICustomHeaders .getCustomHeaders(chatCompletionsOptions).takeIf { it.isNotEmpty() } ? : return null
236+ log.debug(" Adding custom headers: $customHeaders " )
237+ val requestOptions = RequestOptions ()
238+ customHeaders.forEach { key, value ->
239+ requestOptions.addHeader(HttpHeaderName .fromString(key), value)
240+ }
241+ return requestOptions
242+ }
243+
225244 @OptIn(InternalSerializationApi ::class )
226245 private fun OutputSchema.toSchema (): String {
227246 return (type.serializer().descriptor.toSchema() as Schema .ObjectSchema ).copy(
@@ -258,7 +277,8 @@ class AzureAIClient(
258277 }
259278 }
260279 settings?.maxTokens?.let { maxTokens = it }
261- settings?.format?.takeIf { JSON == it }?.let { responseFormat = ChatCompletionsJsonResponseFormat () }
280+ settings?.format?.takeIf { JSON == it && responseFormat == null }
281+ ?.let { responseFormat = ChatCompletionsJsonResponseFormat () }
262282 if (openAIFunctions != null ) tools = openAIFunctions
263283 }
264284
@@ -312,3 +332,18 @@ fun String.extractUseCaseId(): String? {
312332 null
313333 }
314334}
335+
336+ /* *
337+ * Provides a way to set custom headers for Azure OpenAI requests.
338+ */
339+ object AzureOpenAICustomHeaders {
340+
341+ private val customHeaderProvider = AtomicReference < (ChatCompletionsOptions ) -> Map <String , String >> { emptyMap() }
342+
343+ fun setCustomHeaderProvider (provider : (ChatCompletionsOptions ) -> Map <String , String >) {
344+ customHeaderProvider.set(provider)
345+ }
346+
347+ fun getCustomHeaders (chatCompletionsOptions : ChatCompletionsOptions ): Map <String , String > =
348+ customHeaderProvider.get().invoke(chatCompletionsOptions)
349+ }
0 commit comments