Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ For SAP internal development, you can also use `SNAPSHOT` builds from the [inter

The AI SDK leverages the destination concept from the SAP Cloud SDK to manage the connection to AI Core.
This opens up a wide range of possibilities to customize the connection, including adding custom headers.
The following shows how to add custom headers to all requests sent to AI Core.

```java
var service = new AiCoreService();
Expand All @@ -170,15 +171,19 @@ var destination =
// AI Core client
service = service.withBaseDestination(destination);
DeploymentApi client = new DeploymentApi(service);
```

For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations).

// Orchestration client
OrchestrationClient client = new OrchestrationClient(destination);
There is also a convenience method for adding custom headers to individual calls made through the Orchestration or OpenAI client.

// OpenAI client
OpenAiClient client2 = OpenAiClient.withCustomDestination(destination);
```java
var client = new OrchestrationClient();

var result = client.withHeader("my-header-key", "my-header-value").chatCompletion(prompt, config);
```

For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations).
For more information on this feature, see the respective documentation of the [OrchestrationClient](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers) and [OpenAIClient](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers).

### _"There's a vulnerability warning `CVE-2021-41251`?"_

Expand Down
2 changes: 2 additions & 0 deletions docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
- Added `OpenAiChatModel`
- [Prompt Registry] [Using Prompt Registry Templates in SpringAI.](https://sap.github.io/ai-sdk/docs/java/ai-core/prompt-registry#using-templates-in-springai)
- Added `SpringAiConverter`
- [Orchestration] [Added convenience to add custom headers to individual orchestration calls.](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers)
- [OpenAI] [Added convenience to add custom headers to individual LLM calls.](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers)

### 📈 Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand All @@ -49,6 +52,7 @@ public final class OpenAiClient {
@Nullable private String systemPrompt = null;

@Nonnull private final Destination destination;
@Nonnull private final List<Header> customHeaders = new ArrayList<>();

/**
* Create a new OpenAI client for the given foundation model, using the default resource group.
Expand Down Expand Up @@ -127,6 +131,23 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
return this;
}

/**
 * Create a new OpenAI client that sends the given custom header with every call it makes.
 *
 * <p>The returned client uses the same destination as this client and inherits both the custom
 * headers already added to this client and its currently configured system prompt. This client
 * itself is not modified.
 *
 * @param key the key of the custom header to add
 * @param value the value of the custom header to add
 * @return a new client.
 * @since 1.11.0
 */
@Beta
@Nonnull
public OpenAiClient withHeader(@Nonnull final String key, @Nonnull final String value) {
  final var newClient = new OpenAiClient(this.destination);
  // Carry over the system prompt; otherwise withSystemPrompt(...).withHeader(...) would silently
  // drop it, because the new client is constructed from the destination alone.
  newClient.systemPrompt = this.systemPrompt;
  newClient.customHeaders.addAll(this.customHeaders);
  newClient.customHeaders.add(new Header(key, value));
  return newClient;
}

/**
* Generate a completion for the given string prompt as user.
*
Expand Down Expand Up @@ -395,7 +416,7 @@ private <T> T execute(
@Nonnull final Object payload,
@Nonnull final Class<T> responseType) {
final var request = new HttpPost(path);
serializeAndSetHttpEntity(request, payload);
serializeAndSetHttpEntity(request, payload, this.customHeaders);
return executeRequest(request, responseType);
}

Expand All @@ -405,15 +426,18 @@ private <D extends StreamedDelta> Stream<D> executeStream(
@Nonnull final Object payload,
@Nonnull final Class<D> deltaType) {
final var request = new HttpPost(path);
serializeAndSetHttpEntity(request, payload);
serializeAndSetHttpEntity(request, payload, this.customHeaders);
return streamRequest(request, deltaType);
}

private static void serializeAndSetHttpEntity(
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
@Nonnull final BasicClassicHttpRequest request,
@Nonnull final Object payload,
@Nonnull final List<Header> customHeaders) {
try {
final var json = JACKSON.writeValueAsString(payload);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));
} catch (final JsonProcessingException e) {
throw new OpenAiClientException("Failed to serialize request parameters", e)
.setHttpRequest(request);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE;
Expand Down Expand Up @@ -480,4 +481,30 @@ void chatCompletionTool() {
}
""")));
}

@Test
void testCustomHeaders() {
  stubForChatCompletion();
  final var question = "Hello World! Why is this phrase so famous?";

  // Both derived clients must inherit the header of the shared base client.
  final var clientWithHeader = client.withHeader("Header-For-Both", "value");
  final var blockingClient = clientWithHeader.withHeader("foo", "bar");
  final var streamingClient = clientWithHeader.withHeader("foot", "baz");

  final var result = blockingClient.chatCompletion(new OpenAiChatCompletionRequest(question));
  assertThat(result).isNotNull();

  final var streamResult = streamingClient.streamChatCompletion(question);
  assertThat(streamResult).isNotNull();

  // One request must carry the shared header plus the blocking client's own header ...
  verify(
      postRequestedFor(anyUrl())
          .withHeader("Header-For-Both", equalTo("value"))
          .withHeader("foo", equalTo("bar")));
  // ... and another one the shared header plus the streaming client's own header.
  verify(
      postRequestedFor(anyUrl())
          .withHeader("Header-For-Both", equalTo("value"))
          .withHeader("foot", equalTo("baz")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,30 @@
import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions;
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import io.vavr.control.Try;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/** Client to execute requests to the orchestration service. */
@Slf4j
@RequiredArgsConstructor(access = lombok.AccessLevel.PRIVATE)
public class OrchestrationClient {
private static final String DEFAULT_SCENARIO = "orchestration";
private static final String COMPLETION_ENDPOINT = "/v2/completion";

static final ObjectMapper JACKSON = getOrchestrationObjectMapper();

private final OrchestrationHttpExecutor executor;
private final List<Header> customHeaders = new ArrayList<>();

/** Default constructor. */
public OrchestrationClient() {
Expand Down Expand Up @@ -156,7 +161,8 @@ private static Map<String, Object> getOutputFilteringChoices(
@Nonnull
public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request)
throws OrchestrationClientException {
return executor.execute(COMPLETION_ENDPOINT, request, CompletionPostResponse.class);
return executor.execute(
COMPLETION_ENDPOINT, request, CompletionPostResponse.class, customHeaders);
}

/**
Expand Down Expand Up @@ -198,7 +204,8 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
requestJson.set("orchestration_config", moduleConfigJson);

return new OrchestrationChatResponse(
executor.execute(COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class));
executor.execute(
COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class, customHeaders));
}

/**
Expand All @@ -214,7 +221,7 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
request.getConfig().setStream(GlobalStreamOptions.create().enabled(true).delimiters(null));

return executor.stream(COMPLETION_ENDPOINT, request);
return executor.stream(COMPLETION_ENDPOINT, request, customHeaders);
}

/**
Expand All @@ -228,6 +235,24 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull
EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
throws OrchestrationClientException {
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class);
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class, customHeaders);
}

/**
 * Derive an orchestration client that additionally sends the given custom header on each call
 * made with it.
 *
 * @param key the key of the custom header to add
 * @param value the value of the custom header to add
 * @return a new client.
 * @since 1.11.0
 */
@Beta
@Nonnull
public OrchestrationClient withHeader(@Nonnull final String key, @Nonnull final String value) {
  final var derived = new OrchestrationClient(executor);
  // Start from the headers already configured on this client, then append the new one.
  derived.customHeaders.addAll(customHeaders);
  derived.customHeaders.add(new Header(key, value));
  return derived;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import com.sap.ai.sdk.core.common.ClientResponseHandler;
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
import java.io.IOException;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
Expand All @@ -39,12 +41,14 @@ class OrchestrationHttpExecutor {
<T> T execute(
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final Class<T> responseType) {
@Nonnull final Class<T> responseType,
@Nonnull final List<Header> customHeaders) {
try {
val json = JACKSON.writeValueAsString(payload);
log.debug("Successfully serialized request into JSON payload");
val request = new HttpPost(path);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));

val client = getHttpClient();

Expand All @@ -67,12 +71,15 @@ <T> T execute(

@Nonnull
Stream<OrchestrationChatCompletionDelta> stream(
@Nonnull final String path, @Nonnull final Object payload) {
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final List<Header> customHeaders) {
try {

val json = JACKSON.writeValueAsString(payload);
val request = new HttpPost(path);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue()));

val client = getHttpClient();

return new ClientStreamingHandler<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.badRequest;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.noContent;
Expand Down Expand Up @@ -168,6 +169,33 @@ void testCompletionError() {
"Request failed with status 500 (Server Error): Internal Server Error located in Masking Module - Masking");
}

@Test
void testCustomHeaders() {
  final var completionPath = "/v2/completion";
  stubFor(
      post(urlPathEqualTo(completionPath))
          .willReturn(
              aResponse()
                  .withBodyFile("templatingResponse.json")
                  .withHeader("Content-Type", "application/json")));

  // Both derived clients must inherit the header of the shared base client.
  final var clientWithHeader = client.withHeader("Header-For-Both", "value");
  final var blockingClient = clientWithHeader.withHeader("foo", "bar");
  final var streamingClient = clientWithHeader.withHeader("foot", "baz");

  final var result = blockingClient.chatCompletion(prompt, config);
  assertThat(result).isNotNull();

  final var streamResult = streamingClient.streamChatCompletion(prompt, config);
  assertThat(streamResult).isNotNull();

  // One request must carry the shared header plus the blocking client's own header ...
  verify(
      postRequestedFor(urlPathEqualTo(completionPath))
          .withHeader("Header-For-Both", equalTo("value"))
          .withHeader("foo", equalTo("bar")));
  // ... and another one the shared header plus the streaming client's own header.
  verify(
      postRequestedFor(urlPathEqualTo(completionPath))
          .withHeader("Header-For-Both", equalTo("value"))
          .withHeader("foot", equalTo("baz")));
}

@Test
void testGrounding() throws IOException {
stubFor(
Expand Down
Loading