Skip to content

Commit 8fd1de1

Browse files
committed
feat: Add embeddings endpoint and refactor orchestration client
1 parent e327c1b commit 8fd1de1

17 files changed

+3541
-95
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 21 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,18 @@
88
import com.fasterxml.jackson.databind.node.ObjectNode;
99
import com.google.common.annotations.Beta;
1010
import com.sap.ai.sdk.core.AiCoreService;
11-
import com.sap.ai.sdk.core.DeploymentResolutionException;
12-
import com.sap.ai.sdk.core.common.ClientResponseHandler;
13-
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
14-
import com.sap.ai.sdk.core.common.StreamedDelta;
1511
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
1612
import com.sap.ai.sdk.orchestration.model.CompletionPostResponseSynchronous;
13+
import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest;
14+
import com.sap.ai.sdk.orchestration.model.EmbeddingsPostResponse;
1715
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
1816
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
19-
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
2017
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
21-
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
22-
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
23-
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
24-
import java.io.IOException;
2518
import java.util.function.Supplier;
2619
import java.util.stream.Stream;
2720
import javax.annotation.Nonnull;
2821
import lombok.extern.slf4j.Slf4j;
2922
import lombok.val;
30-
import org.apache.hc.client5.http.classic.methods.HttpPost;
31-
import org.apache.hc.core5.http.ContentType;
32-
import org.apache.hc.core5.http.io.entity.StringEntity;
33-
import org.apache.hc.core5.http.message.BasicClassicHttpRequest;
3423

3524
/** Client to execute requests to the orchestration service. */
3625
@Slf4j
@@ -39,12 +28,13 @@ public class OrchestrationClient {
3928

4029
static final ObjectMapper JACKSON = getOrchestrationObjectMapper();
4130

42-
@Nonnull private final Supplier<HttpDestination> destinationSupplier;
31+
private final OrchestrationHttpExecutor executor;
4332

4433
/** Default constructor. */
4534
public OrchestrationClient() {
46-
destinationSupplier =
35+
final Supplier<HttpDestination> destinationSupplier =
4736
() -> new AiCoreService().getInferenceDestination().forScenario(DEFAULT_SCENARIO);
37+
this.executor = new OrchestrationHttpExecutor(destinationSupplier);
4838
}
4939

5040
/**
@@ -64,7 +54,7 @@ public OrchestrationClient() {
6454
*/
6555
@Beta
6656
public OrchestrationClient(@Nonnull final HttpDestination destination) {
67-
this.destinationSupplier = () -> destination;
57+
this.executor = new OrchestrationHttpExecutor(() -> destination);
6858
}
6959

7060
/**
@@ -150,15 +140,7 @@ private static void throwOnContentFilter(@Nonnull final OrchestrationChatComplet
150140
@Nonnull
151141
public CompletionPostResponseSynchronous executeRequest(
152142
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
153-
final String jsonRequest;
154-
try {
155-
jsonRequest = JACKSON.writeValueAsString(request);
156-
log.debug("Serialized request into JSON payload: {}", jsonRequest);
157-
} catch (final JsonProcessingException e) {
158-
throw new OrchestrationClientException("Failed to serialize request parameters", e);
159-
}
160-
161-
return executeRequest(jsonRequest);
143+
return executor.execute("/completion", request, CompletionPostResponseSynchronous.class);
162144
}
163145

164146
/**
@@ -199,38 +181,8 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
199181
}
200182
requestJson.set("orchestration_config", moduleConfigJson);
201183

202-
final String body;
203-
try {
204-
body = JACKSON.writeValueAsString(requestJson);
205-
} catch (JsonProcessingException e) {
206-
throw new OrchestrationClientException("Failed to serialize request to JSON", e);
207-
}
208-
return new OrchestrationChatResponse(executeRequest(body));
209-
}
210-
211-
@Nonnull
212-
CompletionPostResponseSynchronous executeRequest(@Nonnull final String request) {
213-
val postRequest = new HttpPost("/completion");
214-
postRequest.setEntity(new StringEntity(request, ContentType.APPLICATION_JSON));
215-
216-
try {
217-
val destination = destinationSupplier.get();
218-
log.debug("Using destination {} to connect to orchestration service", destination);
219-
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
220-
val handler =
221-
new ClientResponseHandler<>(
222-
CompletionPostResponseSynchronous.class,
223-
OrchestrationError.class,
224-
OrchestrationClientException::new)
225-
.objectMapper(JACKSON);
226-
return client.execute(postRequest, handler);
227-
} catch (DeploymentResolutionException
228-
| DestinationAccessException
229-
| DestinationNotFoundException
230-
| HttpClientInstantiationException
231-
| IOException e) {
232-
throw new OrchestrationClientException("Failed to execute request", e);
233-
}
184+
return new OrchestrationChatResponse(
185+
executor.execute("/completion", requestJson, CompletionPostResponseSynchronous.class));
234186
}
235187

236188
/**
@@ -245,42 +197,20 @@ CompletionPostResponseSynchronous executeRequest(@Nonnull final String request)
245197
public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
246198
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
247199
request.getOrchestrationConfig().setStream(true);
248-
return executeStream("/completion", request, OrchestrationChatCompletionDelta.class);
249-
}
250-
251-
@Nonnull
252-
private <D extends StreamedDelta> Stream<D> executeStream(
253-
@Nonnull final String path,
254-
@Nonnull final Object payload,
255-
@Nonnull final Class<D> deltaType) {
256-
final var request = new HttpPost(path);
257-
serializeAndSetHttpEntity(request, payload);
258-
return streamRequest(request, deltaType);
259-
}
260-
261-
private static void serializeAndSetHttpEntity(
262-
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
263-
try {
264-
final var json = JACKSON.writeValueAsString(payload);
265-
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
266-
} catch (final JsonProcessingException e) {
267-
throw new OrchestrationClientException("Failed to serialize request parameters", e);
268-
}
200+
return executor.stream(request);
269201
}
270202

203+
/**
204+
* Generate embeddings for the given request.
205+
*
206+
* @param request the request containing the input text and other parameters.
207+
* @return the response containing the embeddings.
208+
* @throws OrchestrationClientException if the request fails
209+
* @since 1.9.0
210+
*/
271211
@Nonnull
272-
private <D extends StreamedDelta> Stream<D> streamRequest(
273-
final BasicClassicHttpRequest request, @Nonnull final Class<D> deltaType) {
274-
try {
275-
val destination = destinationSupplier.get();
276-
log.debug("Using destination {} to connect to orchestration service", destination);
277-
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
278-
return new ClientStreamingHandler<>(
279-
deltaType, OrchestrationError.class, OrchestrationClientException::new)
280-
.objectMapper(JACKSON)
281-
.handleStreamingResponse(client.executeOpen(null, request, null));
282-
} catch (final IOException e) {
283-
throw new OrchestrationClientException("Request to the Orchestration service failed", e);
284-
}
212+
public EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
213+
throws OrchestrationClientException {
214+
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class);
285215
}
286216
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper;
4+
5+
import com.fasterxml.jackson.core.JsonProcessingException;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.sap.ai.sdk.core.DeploymentResolutionException;
8+
import com.sap.ai.sdk.core.common.ClientResponseHandler;
9+
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
10+
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
11+
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
12+
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
13+
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
14+
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
15+
import java.io.IOException;
16+
import java.util.function.Supplier;
17+
import java.util.stream.Stream;
18+
import javax.annotation.Nonnull;
19+
import lombok.extern.slf4j.Slf4j;
20+
import lombok.val;
21+
import org.apache.hc.client5.http.classic.HttpClient;
22+
import org.apache.hc.client5.http.classic.methods.HttpPost;
23+
import org.apache.hc.core5.http.ContentType;
24+
import org.apache.hc.core5.http.io.entity.StringEntity;
25+
26+
@Slf4j
27+
class OrchestrationHttpExecutor {
28+
private final Supplier<HttpDestination> destinationSupplier;
29+
private static final ObjectMapper JACKSON = getOrchestrationObjectMapper();
30+
31+
OrchestrationHttpExecutor(@Nonnull final Supplier<HttpDestination> destinationSupplier)
32+
throws OrchestrationClientException {
33+
this.destinationSupplier = destinationSupplier;
34+
}
35+
36+
@Nonnull
37+
<T> T execute(
38+
@Nonnull final String path,
39+
@Nonnull final Object payload,
40+
@Nonnull final Class<T> responseType) {
41+
try {
42+
val json = JACKSON.writeValueAsString(payload);
43+
log.debug("Serialized request into JSON payload: {}", json);
44+
val request = new HttpPost(path);
45+
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
46+
47+
val client = getHttpClient();
48+
49+
val handler =
50+
new ClientResponseHandler<>(
51+
responseType, OrchestrationError.class, OrchestrationClientException::new)
52+
.objectMapper(JACKSON);
53+
return client.execute(request, handler);
54+
55+
} catch (JsonProcessingException e) {
56+
throw new OrchestrationClientException("Failed to serialize request payload for " + path, e);
57+
} catch (DeploymentResolutionException
58+
| DestinationAccessException
59+
| DestinationNotFoundException
60+
| HttpClientInstantiationException
61+
| IOException e) {
62+
throw new OrchestrationClientException(
63+
"Request to Orchestration service failed for " + path, e);
64+
}
65+
}
66+
67+
@Nonnull
68+
Stream<OrchestrationChatCompletionDelta> stream(@Nonnull final Object payload) {
69+
try {
70+
71+
val json = JACKSON.writeValueAsString(payload);
72+
val request = new HttpPost("/completion");
73+
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
74+
val client = getHttpClient();
75+
76+
return new ClientStreamingHandler<>(
77+
OrchestrationChatCompletionDelta.class,
78+
OrchestrationError.class,
79+
OrchestrationClientException::new)
80+
.objectMapper(JACKSON)
81+
.handleStreamingResponse(client.executeOpen(null, request, null));
82+
83+
} catch (JsonProcessingException e) {
84+
throw new OrchestrationClientException(
85+
"Failed to serialize payload for streaming request", e);
86+
} catch (IOException e) {
87+
throw new OrchestrationClientException(
88+
"Streaming request to the Orchestration service failed", e);
89+
}
90+
}
91+
92+
@Nonnull
93+
private HttpClient getHttpClient() {
94+
val destination = destinationSupplier.get();
95+
log.debug("Using destination {} to connect to orchestration service", destination);
96+
return ApacheHttpClient5Accessor.getHttpClient(destination);
97+
}
98+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Internal Orchestration Service API
3+
* Orchestration is an inference service which provides common additional capabilities for business AI scenarios, such as content filtering and data masking. At the core of the service is the LLM module which allows for an easy, harmonized access to the language models of gen AI hub. The service is designed to be modular and extensible, allowing for the addition of new modules in the future. Each module can be configured independently and at runtime, allowing for a high degree of flexibility in the orchestration of AI services.
4+
*
5+
*
6+
*
7+
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
8+
* https://openapi-generator.tech
9+
* Do not edit the class manually.
10+
*/
11+
12+
package com.sap.ai.sdk.orchestration.model;
13+
14+
import java.math.BigDecimal;
15+
import java.util.List;
16+
import javax.annotation.Nonnull;
17+
18+
/** Embedding */
19+
public interface Embedding {
20+
/** Helper class to create a String that implements {@link Embedding}. */
21+
record InnerString(@com.fasterxml.jackson.annotation.JsonValue @Nonnull String value)
22+
implements Embedding {}
23+
24+
/**
25+
* Creator to enable deserialization of a String.
26+
*
27+
* @param val the value to use
28+
* @return a new instance of {@link InnerString}.
29+
*/
30+
@com.fasterxml.jackson.annotation.JsonCreator
31+
@Nonnull
32+
static InnerString create(@Nonnull final String val) {
33+
return new InnerString(val);
34+
}
35+
36+
/** Helper class to create a list of BigDecimal that implements {@link Embedding}. */
37+
record InnerBigDecimals(
38+
@com.fasterxml.jackson.annotation.JsonValue @Nonnull List<BigDecimal> values)
39+
implements Embedding {}
40+
41+
/**
42+
* Creator to enable deserialization of a list of BigDecimal.
43+
*
44+
* @param val the value to use
45+
* @return a new instance of {@link InnerBigDecimals}.
46+
*/
47+
@com.fasterxml.jackson.annotation.JsonCreator
48+
@Nonnull
49+
static InnerBigDecimals create(@Nonnull final List<BigDecimal> val) {
50+
return new InnerBigDecimals(val);
51+
}
52+
}

0 commit comments

Comments
 (0)