Skip to content

Commit 974375a

Browse files
Jonas-Isrnewtorkbot-sdk-js
authored
feat: Convenience for custom headers for orchestration and OpenAI (#550)
* orchestration client * openAI client * javadoc, release notes, etc * add note to FAQ in README.md * add actual links * Update foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java Co-authored-by: Alexander Dümont <22489773+newtork@users.noreply.github.com> * Formatting * add @SInCE * apply suggestions * apply review * apply suggestions * apply suggestions --------- Co-authored-by: Alexander Dümont <22489773+newtork@users.noreply.github.com> Co-authored-by: SAP Cloud SDK Bot <cloudsdk@sap.com>
1 parent 02d33e8 commit 974375a

File tree

7 files changed

+133
-15
lines changed

7 files changed

+133
-15
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ For SAP internal development, you can also use `SNAPSHOT` builds from the [inter
159159

160160
The AI SDK leverages the destination concept from the SAP Cloud SDK to manage the connection to AI Core.
161161
This opens up a wide range of possibilities to customize the connection, including adding custom headers.
162+
The following shows how to add custom headers to all requests sent to AI Core.
162163

163164
```java
164165
var service = new AiCoreService();
@@ -170,15 +171,19 @@ var destination =
170171
// AI Core client
171172
service = service.withBaseDestination(destination);
172173
DeploymentApi client = new DeploymentApi(service);
174+
```
175+
176+
For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations).
173177

174-
// Orchestration client
175-
OrchestrationClient client = new OrchestrationClient(destination);
178+
There is also a convenient method to add custom headers to single calls through the Orchestration or OpenAI client.
176179

177-
// OpenAI client
178-
OpenAiClient client2 = OpenAiClient.withCustomDestination(destination);
180+
```java
181+
var client = new OrchestrationClient();
182+
183+
var result = client.withHeader("my-header-key", "my-header-value").chatCompletion(prompt, config);
179184
```
180185

181-
For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations).
186+
For more information on this feature, see the respective documentation of the [OrchestrationClient](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers) and [OpenAIClient](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers).
182187

183188
### _"There's a vulnerability warning `CVE-2021-41251`?"_
184189

docs/release_notes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
- Added `OpenAiChatModel`
2828
- [Prompt Registry] [Using Prompt Registry Templates in SpringAI.](https://sap.github.io/ai-sdk/docs/java/ai-core/prompt-registry#using-templates-in-springai)
2929
- Added `SpringAiConverter`
30+
- [Orchestration] [Added convenience to add custom headers to individual orchestration calls.](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers)
31+
- [OpenAI] [Added convenience to add custom headers to individual LLM calls.](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers)
3032

3133
### 📈 Improvements
3234

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@
2525
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
2626
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
2727
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
28+
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
2829
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
2930
import java.io.IOException;
31+
import java.util.ArrayList;
32+
import java.util.List;
3033
import java.util.stream.Stream;
3134
import javax.annotation.Nonnull;
3235
import javax.annotation.Nullable;
@@ -49,6 +52,7 @@ public final class OpenAiClient {
4952
@Nullable private String systemPrompt = null;
5053

5154
@Nonnull private final Destination destination;
55+
@Nonnull private final List<Header> customHeaders = new ArrayList<>();
5256

5357
/**
5458
* Create a new OpenAI client for the given foundation model, using the default resource group.
@@ -127,6 +131,23 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
127131
return this;
128132
}
129133

134+
/**
135+
* Create a new OpenAI client with a custom header added to every call made with this client
136+
*
137+
* @param key the key of the custom header to add
138+
* @param value the value of the custom header to add
139+
* @return a new client.
140+
* @since 1.11.0
141+
*/
142+
@Beta
143+
@Nonnull
144+
public OpenAiClient withHeader(@Nonnull final String key, @Nonnull final String value) {
145+
final var newClient = new OpenAiClient(this.destination);
146+
newClient.customHeaders.addAll(this.customHeaders);
147+
newClient.customHeaders.add(new Header(key, value));
148+
return newClient;
149+
}
150+
130151
/**
131152
* Generate a completion for the given string prompt as user.
132153
*
@@ -395,7 +416,7 @@ private <T> T execute(
395416
@Nonnull final Object payload,
396417
@Nonnull final Class<T> responseType) {
397418
final var request = new HttpPost(path);
398-
serializeAndSetHttpEntity(request, payload);
419+
serializeAndSetHttpEntity(request, payload, this.customHeaders);
399420
return executeRequest(request, responseType);
400421
}
401422

@@ -405,15 +426,18 @@ private <D extends StreamedDelta> Stream<D> executeStream(
405426
@Nonnull final Object payload,
406427
@Nonnull final Class<D> deltaType) {
407428
final var request = new HttpPost(path);
408-
serializeAndSetHttpEntity(request, payload);
429+
serializeAndSetHttpEntity(request, payload, this.customHeaders);
409430
return streamRequest(request, deltaType);
410431
}
411432

412433
private static void serializeAndSetHttpEntity(
413-
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
434+
@Nonnull final BasicClassicHttpRequest request,
435+
@Nonnull final Object payload,
436+
@Nonnull final List<Header> customHeaders) {
414437
try {
415438
final var json = JACKSON.writeValueAsString(payload);
416439
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
440+
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));
417441
} catch (final JsonProcessingException e) {
418442
throw new OpenAiClientException("Failed to serialize request parameters", e)
419443
.setHttpRequest(request);

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.sap.ai.sdk.foundationmodels.openai;
22

33
import static com.github.tomakehurst.wiremock.client.WireMock.*;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
45
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;
56
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*;
67
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE;
@@ -480,4 +481,30 @@ void chatCompletionTool() {
480481
}
481482
""")));
482483
}
484+
485+
@Test
486+
void testCustomHeaders() {
487+
stubForChatCompletion();
488+
final var request =
489+
new OpenAiChatCompletionRequest("Hello World! Why is this phrase so famous?");
490+
final var clientWithHeader = client.withHeader("Header-For-Both", "value");
491+
492+
final var result = clientWithHeader.withHeader("foo", "bar").chatCompletion(request);
493+
assertThat(result).isNotNull();
494+
495+
var streamResult =
496+
clientWithHeader
497+
.withHeader("foot", "baz")
498+
.streamChatCompletion("Hello World! Why is this phrase so famous?");
499+
assertThat(streamResult).isNotNull();
500+
501+
verify(
502+
postRequestedFor(anyUrl())
503+
.withHeader("Header-For-Both", equalTo("value"))
504+
.withHeader("foo", equalTo("bar")));
505+
verify(
506+
postRequestedFor(anyUrl())
507+
.withHeader("Header-For-Both", equalTo("value"))
508+
.withHeader("foot", equalTo("baz")));
509+
}
483510
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,30 @@
1515
import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions;
1616
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
1717
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
18+
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
1819
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
1920
import io.vavr.control.Try;
21+
import java.util.ArrayList;
2022
import java.util.List;
2123
import java.util.Map;
2224
import java.util.function.Supplier;
2325
import java.util.stream.Stream;
2426
import javax.annotation.Nonnull;
27+
import lombok.RequiredArgsConstructor;
2528
import lombok.extern.slf4j.Slf4j;
2629
import lombok.val;
2730

2831
/** Client to execute requests to the orchestration service. */
2932
@Slf4j
33+
@RequiredArgsConstructor(access = lombok.AccessLevel.PRIVATE)
3034
public class OrchestrationClient {
3135
private static final String DEFAULT_SCENARIO = "orchestration";
3236
private static final String COMPLETION_ENDPOINT = "/v2/completion";
3337

3438
static final ObjectMapper JACKSON = getOrchestrationObjectMapper();
3539

3640
private final OrchestrationHttpExecutor executor;
41+
private final List<Header> customHeaders = new ArrayList<>();
3742

3843
/** Default constructor. */
3944
public OrchestrationClient() {
@@ -156,7 +161,8 @@ private static Map<String, Object> getOutputFilteringChoices(
156161
@Nonnull
157162
public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request)
158163
throws OrchestrationClientException {
159-
return executor.execute(COMPLETION_ENDPOINT, request, CompletionPostResponse.class);
164+
return executor.execute(
165+
COMPLETION_ENDPOINT, request, CompletionPostResponse.class, customHeaders);
160166
}
161167

162168
/**
@@ -198,7 +204,8 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
198204
requestJson.set("orchestration_config", moduleConfigJson);
199205

200206
return new OrchestrationChatResponse(
201-
executor.execute(COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class));
207+
executor.execute(
208+
COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class, customHeaders));
202209
}
203210

204211
/**
@@ -214,7 +221,7 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
214221
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
215222
request.getConfig().setStream(GlobalStreamOptions.create().enabled(true).delimiters(null));
216223

217-
return executor.stream(COMPLETION_ENDPOINT, request);
224+
return executor.stream(COMPLETION_ENDPOINT, request, customHeaders);
218225
}
219226

220227
/**
@@ -228,6 +235,24 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
228235
@Nonnull
229236
EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
230237
throws OrchestrationClientException {
231-
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class);
238+
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class, customHeaders);
239+
}
240+
241+
/**
242+
* Create a new orchestration client with a custom header added to every call made with this
243+
* client
244+
*
245+
* @param key the key of the custom header to add
246+
* @param value the value of the custom header to add
247+
* @return a new client.
248+
* @since 1.11.0
249+
*/
250+
@Beta
251+
@Nonnull
252+
public OrchestrationClient withHeader(@Nonnull final String key, @Nonnull final String value) {
253+
final var newClient = new OrchestrationClient(this.executor);
254+
newClient.customHeaders.addAll(this.customHeaders);
255+
newClient.customHeaders.add(new Header(key, value));
256+
return newClient;
232257
}
233258
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
import com.sap.ai.sdk.core.common.ClientResponseHandler;
1010
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
1111
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
12+
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
1213
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
1314
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
1415
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
1516
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
1617
import java.io.IOException;
18+
import java.util.List;
1719
import java.util.function.Supplier;
1820
import java.util.stream.Stream;
1921
import javax.annotation.Nonnull;
@@ -39,12 +41,14 @@ class OrchestrationHttpExecutor {
3941
<T> T execute(
4042
@Nonnull final String path,
4143
@Nonnull final Object payload,
42-
@Nonnull final Class<T> responseType) {
44+
@Nonnull final Class<T> responseType,
45+
@Nonnull final List<Header> customHeaders) {
4346
try {
4447
val json = JACKSON.writeValueAsString(payload);
4548
log.debug("Successfully serialized request into JSON payload");
4649
val request = new HttpPost(path);
4750
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
51+
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));
4852

4953
val client = getHttpClient();
5054

@@ -67,12 +71,15 @@ <T> T execute(
6771

6872
@Nonnull
6973
Stream<OrchestrationChatCompletionDelta> stream(
70-
@Nonnull final String path, @Nonnull final Object payload) {
74+
@Nonnull final String path,
75+
@Nonnull final Object payload,
76+
@Nonnull final List<Header> customHeaders) {
7177
try {
72-
7378
val json = JACKSON.writeValueAsString(payload);
7479
val request = new HttpPost(path);
7580
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
81+
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));
82+
7683
val client = getHttpClient();
7784

7885
return new ClientStreamingHandler<>(

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
44
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
55
import static com.github.tomakehurst.wiremock.client.WireMock.badRequest;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
67
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
78
import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse;
89
import static com.github.tomakehurst.wiremock.client.WireMock.noContent;
@@ -168,6 +169,33 @@ void testCompletionError() {
168169
"Request failed with status 500 (Server Error): Internal Server Error located in Masking Module - Masking");
169170
}
170171

172+
@Test
173+
void testCustomHeaders() {
174+
stubFor(
175+
post(urlPathEqualTo("/v2/completion"))
176+
.willReturn(
177+
aResponse()
178+
.withBodyFile("templatingResponse.json")
179+
.withHeader("Content-Type", "application/json")));
180+
181+
final var clientWithHeader = client.withHeader("Header-For-Both", "value");
182+
final var result = clientWithHeader.withHeader("foo", "bar").chatCompletion(prompt, config);
183+
assertThat(result).isNotNull();
184+
185+
var streamResult =
186+
clientWithHeader.withHeader("foot", "baz").streamChatCompletion(prompt, config);
187+
assertThat(streamResult).isNotNull();
188+
189+
verify(
190+
postRequestedFor(urlPathEqualTo("/v2/completion"))
191+
.withHeader("Header-For-Both", equalTo("value"))
192+
.withHeader("foo", equalTo("bar")));
193+
verify(
194+
postRequestedFor(urlPathEqualTo("/v2/completion"))
195+
.withHeader("Header-For-Both", equalTo("value"))
196+
.withHeader("foot", equalTo("baz")));
197+
}
198+
171199
@Test
172200
void testGrounding() throws IOException {
173201
stubFor(

0 commit comments

Comments
 (0)