Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ OpenAiClient client2 = OpenAiClient.withCustomDestination(destination);

For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations).

There is also a convenient method to add custom headers to single calls through the orchestration or OpenAI client.
For more information, see the respective documentation of the [OrchestrationClient](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers) and [OpenAIClient](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers).

### _"There's a vulnerability warning `CVE-2021-41251`?"_

This is a known false-positive finding.
Expand Down
2 changes: 2 additions & 0 deletions docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
- Added `OpenAiChatModel`
- [Prompt Registry] [Using Prompt Registry Templates in SpringAI.](https://sap.github.io/ai-sdk/docs/java/ai-core/prompt-registry#using-templates-in-springai)
- Added `SpringAiConverter`
- [Orchestration] [Added convenience to add custom headers to individual orchestration calls.](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers)
- [OpenAI] [Added convenience to add custom headers to individual LLM calls.](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers)

### 📈 Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import java.io.IOException;
import java.util.stream.Stream;
Expand Down Expand Up @@ -128,6 +129,37 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
return this;
}

/**
* Create a new OpenAI client with a custom HTTP request header added to every call made with this
* client
*
* @param customHeader the custom header to add
* @return a new client.
* @since 1.11.0
*/
@Beta
@Nonnull
public OpenAiClient withHeader(@Nonnull final Header customHeader) {
final var newDestination =
DefaultHttpDestination.fromDestination(this.destination).header(customHeader).build();
return new OpenAiClient(newDestination);
}

/**
* Create a new OpenAI client with a custom header added to every call made with this client
*
* @param key the key of the custom header to add
* @param value the value of the custom header to add
* @return a new client.
* @since 1.11.0
*/
@Beta
@Nonnull
public OpenAiClient withHeader(@Nonnull final String key, @Nonnull final String value) {
final var customHeader = new Header(key, value);
return this.withHeader(customHeader);
}

/**
* Generate a completion for the given string prompt as user.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE;
Expand All @@ -15,6 +16,7 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterPromptResults;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import java.io.IOException;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -480,4 +482,33 @@ void chatCompletionTool() {
}
""")));
}

@Test
void testCustomHeaders() {
stubForChatCompletion();
final var request =
new OpenAiChatCompletionRequest("Hello World! Why is this phrase so famous?");

var customHeader = new Header("foo", "bar");
final var result =
client.withHeader("footoo", "barzar").withHeader(customHeader).chatCompletion(request);
assertThat(result).isNotNull();

var newCustomHeader = new Header("foo", "baz");
var streamResult =
client
.withHeader("footoo", "barz")
.withHeader(newCustomHeader)
.streamChatCompletion("Hello World! Why is this phrase so famous?");
assertThat(streamResult).isNotNull();

verify(
postRequestedFor(anyUrl())
.withHeader("foo", equalTo("bar"))
.withHeader("footoo", equalTo("barzar")));
verify(
postRequestedFor(anyUrl())
.withHeader("foo", equalTo("baz"))
.withHeader("footoo", equalTo("barz")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions;
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import io.vavr.control.Try;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

Expand All @@ -34,6 +37,7 @@ public class OrchestrationClient {
static final ObjectMapper JACKSON = getOrchestrationObjectMapper();

private final OrchestrationHttpExecutor executor;
private final List<Header> customHeaders = new ArrayList<>();

/** Default constructor. */
public OrchestrationClient() {
Expand Down Expand Up @@ -62,6 +66,15 @@ public OrchestrationClient(@Nonnull final HttpDestination destination) {
this.executor = new OrchestrationHttpExecutor(() -> destination);
}

private OrchestrationClient(
@Nonnull final OrchestrationHttpExecutor executor,
@Nullable final List<Header> customHeaders) {
this.executor = executor;
if (customHeaders != null) {
this.customHeaders.addAll(customHeaders);
}
}

/**
* Convert the given prompt and config into a low-level request data object. The data object
* allows for further customization before sending the request.
Expand Down Expand Up @@ -156,7 +169,8 @@ private static Map<String, Object> getOutputFilteringChoices(
@Nonnull
public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request)
throws OrchestrationClientException {
return executor.execute(COMPLETION_ENDPOINT, request, CompletionPostResponse.class);
return executor.execute(
COMPLETION_ENDPOINT, request, CompletionPostResponse.class, customHeaders);
}

/**
Expand Down Expand Up @@ -198,7 +212,8 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
requestJson.set("orchestration_config", moduleConfigJson);

return new OrchestrationChatResponse(
executor.execute(COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class));
executor.execute(
COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class, customHeaders));
}

/**
Expand All @@ -214,7 +229,7 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
request.getConfig().setStream(GlobalStreamOptions.create().enabled(true).delimiters(null));

return executor.stream(COMPLETION_ENDPOINT, request);
return executor.stream(COMPLETION_ENDPOINT, request, customHeaders);
}

/**
Expand All @@ -228,6 +243,37 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull
EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
throws OrchestrationClientException {
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class);
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class, customHeaders);
}

/**
* Create a new orchestration client with a custom header added to every call made with this
* client
*
* @param key the key of the custom header to add
* @param value the value of the custom header to add
* @return a new client.
* @since 1.11.0
*/
@Beta
@Nonnull
public OrchestrationClient withHeader(@Nonnull final String key, @Nonnull final String value) {
return this.withHeader(new Header(key, value));
}

/**
* Create a new orchestration client with a custom header added to every call made with this
* client
*
* @param customHeader the custom header to add
* @return a new client.
* @since 1.11.0
*/
@Beta
@Nonnull
public OrchestrationClient withHeader(@Nonnull final Header customHeader) {
final var newClient = new OrchestrationClient(this.executor, this.customHeaders);
newClient.customHeaders.add(customHeader);
return newClient;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
import com.sap.ai.sdk.core.common.ClientResponseHandler;
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
import java.io.IOException;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
Expand All @@ -39,14 +42,15 @@ class OrchestrationHttpExecutor {
<T> T execute(
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final Class<T> responseType) {
@Nonnull final Class<T> responseType,
@Nonnull final List<Header> customHeaders) {
try {
val json = JACKSON.writeValueAsString(payload);
log.debug("Successfully serialized request into JSON payload");
val request = new HttpPost(path);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));

val client = getHttpClient();
val client = getHttpClient(customHeaders);

val handler =
new ClientResponseHandler<>(responseType, OrchestrationError.Synchronous.class, FACTORY)
Expand All @@ -67,13 +71,15 @@ <T> T execute(

@Nonnull
Stream<OrchestrationChatCompletionDelta> stream(
@Nonnull final String path, @Nonnull final Object payload) {
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final List<Header> customHeaders) {
try {

val json = JACKSON.writeValueAsString(payload);
val request = new HttpPost(path);
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
val client = getHttpClient();
val client = getHttpClient(customHeaders);

return new ClientStreamingHandler<>(
OrchestrationChatCompletionDelta.class, OrchestrationError.Streaming.class, FACTORY)
Expand All @@ -90,8 +96,11 @@ Stream<OrchestrationChatCompletionDelta> stream(
}

@Nonnull
private HttpClient getHttpClient() {
val destination = destinationSupplier.get();
private HttpClient getHttpClient(@Nonnull final List<Header> customHeaders) {
val destination =
DefaultHttpDestination.fromDestination(destinationSupplier.get())
.headers(customHeaders)
.build();
log.debug("Using destination {} to connect to orchestration service", destination);
return ApacheHttpClient5Accessor.getHttpClient(destination);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.badRequest;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.noContent;
Expand Down Expand Up @@ -77,6 +78,7 @@
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigDecimal;
Expand Down Expand Up @@ -168,6 +170,41 @@ void testCompletionError() {
"Request failed with status 500 (Server Error): Internal Server Error located in Masking Module - Masking");
}

@Test
void testCustomHeaders() {
stubFor(
post(urlPathEqualTo("/v2/completion"))
.willReturn(
aResponse()
.withBodyFile("templatingResponse.json")
.withHeader("Content-Type", "application/json")));

var customHeader = new Header("foo", "bar");
final var result =
client
.withHeader("footoo", "barzar")
.withHeader(customHeader)
.chatCompletion(prompt, config);
assertThat(result).isNotNull();

var newCustomHeader = new Header("foo", "baz");
var streamResult =
client
.withHeader("footoo", "barz")
.withHeader(newCustomHeader)
.streamChatCompletion(prompt, config);
assertThat(streamResult).isNotNull();

verify(
postRequestedFor(urlPathEqualTo("/v2/completion"))
.withHeader("foo", equalTo("bar"))
.withHeader("footoo", equalTo("barzar")));
verify(
postRequestedFor(urlPathEqualTo("/v2/completion"))
.withHeader("foo", equalTo("baz"))
.withHeader("footoo", equalTo("barz")));
}

@Test
void testGrounding() throws IOException {
stubFor(
Expand Down