Skip to content

Commit c6b9421

Browse files
authored
feat: [Orchestration] SpringAi Embedding Integration (#621)
* Generate new model classes for float[] * Regenerate for all modules * Add release notes * Add EmbeddingModel Add Unit and E2E test * Release notes * Add review suggestions * Updatw @SInCE to mark for next release * pointless ci failure fix
1 parent fa5aa6d commit c6b9421

File tree

11 files changed

+327
-10
lines changed

11 files changed

+327
-10
lines changed

docs/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- [Orchestration] Introduced Spring AI integration for embeddings generation with the new `OrchestrationSpringAiEmbeddingModel` class.
1616

1717
### 📈 Improvements
1818

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiSpringEmbeddingModel.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
66
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
77
import java.util.Objects;
8+
import java.util.Optional;
89
import java.util.stream.IntStream;
910
import javax.annotation.Nonnull;
1011
import org.springframework.ai.chat.metadata.DefaultUsage;
1112
import org.springframework.ai.document.Document;
1213
import org.springframework.ai.document.MetadataMode;
1314
import org.springframework.ai.embedding.Embedding;
1415
import org.springframework.ai.embedding.EmbeddingModel;
16+
import org.springframework.ai.embedding.EmbeddingOptions;
1517
import org.springframework.ai.embedding.EmbeddingRequest;
1618
import org.springframework.ai.embedding.EmbeddingResponse;
1719
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
@@ -26,8 +28,8 @@
2628
*/
2729
public class OpenAiSpringEmbeddingModel implements EmbeddingModel {
2830

29-
private final OpenAiClient client;
30-
private final MetadataMode metadataMode;
31+
@Nonnull private final OpenAiClient client;
32+
@Nonnull private final MetadataMode metadataMode;
3133

3234
/**
3335
* Constructs an {@code OpenAiSpringEmbeddingModel} with the specified {@link OpenAiClient} of
@@ -65,12 +67,6 @@ public OpenAiSpringEmbeddingModel(
6567
@Nonnull
6668
public EmbeddingResponse call(@Nonnull final EmbeddingRequest request)
6769
throws IllegalArgumentException {
68-
69-
if (request.getOptions().getModel() != null) {
70-
throw new IllegalArgumentException(
71-
"Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
72-
}
73-
7470
final var openAiRequest = createEmbeddingsCreateRequest(request);
7571
final var openAiResponse = client.embedding(openAiRequest);
7672

@@ -88,8 +84,15 @@ public float[] embed(@Nonnull final Document document) throws UnsupportedOperati
8884

8985
private EmbeddingsCreateRequest createEmbeddingsCreateRequest(
9086
@Nonnull final EmbeddingRequest request) {
87+
88+
final var options = Optional.ofNullable(request.getOptions());
89+
if (options.map(EmbeddingOptions::getModel).isPresent()) {
90+
throw new IllegalArgumentException(
91+
"Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
92+
}
93+
9194
return new EmbeddingsCreateRequest()
92-
.dimensions(request.getOptions().getDimensions())
95+
.dimensions(options.map(EmbeddingOptions::getDimensions).orElse(null))
9396
.input(EmbeddingsCreateRequestInput.createListOfStrings(request.getInstructions()));
9497
}
9598

orchestration/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
<artifactId>spring-ai-model</artifactId>
6565
<optional>true</optional>
6666
</dependency>
67+
<dependency>
68+
<groupId>org.springframework.ai</groupId>
69+
<artifactId>spring-ai-commons</artifactId>
70+
<optional>true</optional>
71+
</dependency>
6772
<dependency>
6873
<groupId>io.projectreactor</groupId>
6974
<artifactId>reactor-core</artifactId>
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.orchestration.OrchestrationClient;
5+
import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingModel;
6+
import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingRequest;
7+
import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingResponse;
8+
import java.util.List;
9+
import java.util.Objects;
10+
import java.util.stream.IntStream;
11+
import javax.annotation.Nonnull;
12+
import lombok.RequiredArgsConstructor;
13+
import org.springframework.ai.chat.metadata.DefaultUsage;
14+
import org.springframework.ai.document.Document;
15+
import org.springframework.ai.document.MetadataMode;
16+
import org.springframework.ai.embedding.Embedding;
17+
import org.springframework.ai.embedding.EmbeddingModel;
18+
import org.springframework.ai.embedding.EmbeddingOptions;
19+
import org.springframework.ai.embedding.EmbeddingRequest;
20+
import org.springframework.ai.embedding.EmbeddingResponse;
21+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
22+
23+
/**
24+
* A Spring-based implementation of the {@link EmbeddingModel} interface that integrates with the
25+
* Orchestration SDK to provide embedding functionality.
26+
*
27+
* @since 1.13.0
28+
*/
29+
@Beta
30+
@RequiredArgsConstructor
31+
public class OrchestrationSpringEmbeddingModel implements EmbeddingModel {
32+
33+
/**
34+
* Default embedding options to provide model name and other parameters.
35+
*
36+
* <p>Can be overridden by options in the request.
37+
*
38+
* @see OrchestrationSpringEmbeddingModel#call(EmbeddingRequest)
39+
*/
40+
@Nonnull private final EmbeddingOptions defaultOptions;
41+
42+
/** Client for interacting with the Orchestration SDK. */
43+
@Nonnull private final OrchestrationClient client;
44+
45+
/** Metadata mode to determine how document metadata is handled. */
46+
@Nonnull private final MetadataMode metadataMode;
47+
48+
/**
49+
* Constructs an instance with default options, a new {@link OrchestrationClient}, and sets the
50+
* metadata mode to {@link MetadataMode#EMBED}.
51+
*
52+
* @param defaultOptions Default embedding options.
53+
*/
54+
public OrchestrationSpringEmbeddingModel(@Nonnull final EmbeddingOptions defaultOptions) {
55+
this(defaultOptions, new OrchestrationClient(), MetadataMode.EMBED);
56+
}
57+
58+
/**
59+
* Calls the embedding model with the given request and returns the response.
60+
*
61+
* <p>Note: The request's options takes precedence over the defaultOptions.
62+
*
63+
* @param request The embedding request containing input texts and options.
64+
* @return The embedding response containing results and metadata.
65+
*/
66+
@Override
67+
@Nonnull
68+
public EmbeddingResponse call(@Nonnull final EmbeddingRequest request) {
69+
final var orchestrationRequest = createOrchestrationEmbeddingRequest(request);
70+
final var orchestrationResponse = client.embed(orchestrationRequest);
71+
return createSpringAiEmbeddingResponse(orchestrationResponse);
72+
}
73+
74+
@Override
75+
@Nonnull
76+
public float[] embed(@Nonnull final Document document) {
77+
return embed(document.getFormattedContent(this.metadataMode));
78+
}
79+
80+
@Override
81+
@Nonnull
82+
public List<float[]> embed(@Nonnull final List<String> texts) {
83+
// Propagate defaultOptions instead of incomplete options in default method implementation
84+
final var response = this.call(new EmbeddingRequest(texts, this.defaultOptions));
85+
return response.getResults().stream().map(Embedding::getOutput).toList();
86+
}
87+
88+
@Nonnull
89+
private OrchestrationEmbeddingRequest createOrchestrationEmbeddingRequest(
90+
@Nonnull final EmbeddingRequest request) {
91+
final var options = Objects.requireNonNullElse(request.getOptions(), defaultOptions);
92+
final var modelName =
93+
Objects.requireNonNull(options.getModel(), "EmbeddingOptions must provide the model name");
94+
final var model =
95+
new OrchestrationEmbeddingModel(modelName).withDimensions(options.getDimensions());
96+
return OrchestrationEmbeddingRequest.forModel(model).forInputs(request.getInstructions());
97+
}
98+
99+
@Nonnull
100+
private EmbeddingResponse createSpringAiEmbeddingResponse(
101+
@Nonnull final OrchestrationEmbeddingResponse response) {
102+
final var embeddings =
103+
IntStream.range(0, response.getEmbeddingVectors().size())
104+
.mapToObj(i -> new Embedding(response.getEmbeddingVectors().get(i), i))
105+
.toList();
106+
final var finalResult = response.getOriginalResponse().getFinalResult();
107+
final var orchUsage = finalResult.getUsage();
108+
final var usage =
109+
new DefaultUsage(orchUsage.getPromptTokens(), null, orchUsage.getTotalTokens());
110+
final var metadata = new EmbeddingResponseMetadata(finalResult.getModel(), usage);
111+
112+
return new EmbeddingResponse(embeddings, metadata);
113+
}
114+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
7+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
8+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
9+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
10+
import static org.assertj.core.api.Assertions.assertThat;
11+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
12+
13+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
14+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
15+
import com.sap.ai.sdk.orchestration.OrchestrationClient;
16+
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
17+
import java.io.IOException;
18+
import java.io.InputStream;
19+
import java.util.List;
20+
import java.util.Objects;
21+
import java.util.function.Function;
22+
import jdk.jfr.Description;
23+
import org.junit.jupiter.api.BeforeEach;
24+
import org.junit.jupiter.api.Test;
25+
import org.springframework.ai.document.MetadataMode;
26+
import org.springframework.ai.embedding.EmbeddingOptions;
27+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
28+
import org.springframework.ai.embedding.EmbeddingRequest;
29+
30+
@WireMockTest
31+
class OrchestrationEmbeddingModelTest {
32+
33+
private static EmbeddingOptions options;
34+
private static OrchestrationSpringEmbeddingModel model;
35+
private final Function<String, InputStream> fileLoader =
36+
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
37+
38+
@BeforeEach
39+
void setup(WireMockRuntimeInfo server) {
40+
options = EmbeddingOptionsBuilder.builder().withModel("text-embedding-3-small").build();
41+
42+
final var destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build();
43+
final var client = new OrchestrationClient(destination);
44+
model = new OrchestrationSpringEmbeddingModel(options, client, MetadataMode.EMBED);
45+
}
46+
47+
@Test
48+
void testEmbeddingRequest() throws IOException {
49+
stubForEmbedding();
50+
51+
final var inputText = "Hi SAP Orchestration Service";
52+
final var embeddingRequest = new EmbeddingRequest(List.of(inputText), null);
53+
final var response = model.call(embeddingRequest);
54+
55+
assertThat(response.getResult().getOutput())
56+
.isEqualTo(
57+
new float[] {-0.003806071f, -0.01453408f, 0.037058588f, -0.012397106f, 0.0029582495f});
58+
assertThat(response.getMetadata().getModel()).isEqualTo(options.getModel());
59+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6);
60+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6);
61+
62+
verifyEmbeddingRequest();
63+
}
64+
65+
@Test
66+
void testEmbeddingText() throws IOException {
67+
stubForEmbedding();
68+
69+
final var inputText = "Hi SAP Orchestration Service";
70+
final var response = model.embed(inputText);
71+
assertThat(response)
72+
.isEqualTo(
73+
new float[] {-0.003806071f, -0.01453408f, 0.037058588f, -0.012397106f, 0.0029582495f});
74+
75+
verifyEmbeddingRequest();
76+
}
77+
78+
@Test
79+
@Description("Tests that model must name must be set and request option precedes over default")
80+
void testEmbeddingWithMissingModelNameThrows() {
81+
final var request =
82+
new EmbeddingRequest(List.of("Hello World"), EmbeddingOptionsBuilder.builder().build());
83+
84+
assertThatThrownBy(() -> model.call(request))
85+
.isInstanceOf(NullPointerException.class)
86+
.hasMessageContaining("EmbeddingOptions must provide the model name");
87+
}
88+
89+
private void stubForEmbedding() {
90+
stubFor(
91+
post(urlEqualTo("/v2/embeddings"))
92+
.willReturn(
93+
aResponse()
94+
.withStatus(200)
95+
.withHeader("Content-Type", "application/json")
96+
.withBodyFile("simpleEmbeddingResponse.json")));
97+
}
98+
99+
private void verifyEmbeddingRequest() throws IOException {
100+
try (var inputStream = fileLoader.apply("simpleEmbeddingRequest.json")) {
101+
var requestJson = new String(inputStream.readAllBytes());
102+
verify(
103+
postRequestedFor(urlEqualTo("/v2/embeddings")).withRequestBody(equalToJson(requestJson)));
104+
}
105+
}
106+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"request_id": "a99625bc-1dc2-9b95-82c1-c6175dc7407b",
3+
"final_result": {
4+
"object": "list",
5+
"data": [
6+
{
7+
"object": "embedding",
8+
"embedding": [
9+
-0.003806071,
10+
-0.01453408,
11+
0.037058588,
12+
-0.012397106,
13+
0.0029582495
14+
],
15+
"index": 0
16+
}
17+
],
18+
"model": "text-embedding-3-small",
19+
"usage": {
20+
"prompt_tokens": 6,
21+
"total_tokens": 6
22+
}
23+
}
24+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"config": {
3+
"modules": {
4+
"embeddings": {
5+
"model": {
6+
"name": "text-embedding-3-small",
7+
"params": {
8+
"encoding_format": "float"
9+
},
10+
"timeout": 600,
11+
"max_retries": 2
12+
}
13+
}
14+
}
15+
},
16+
"input": {
17+
"text": [
18+
"Hi SAP Orchestration Service"
19+
]
20+
}
21+
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,10 @@ Object chatMemory(
157157
}
158158
return response.getResult().getOutput().getText();
159159
}
160+
161+
@GetMapping("/embed/string")
162+
Object embedding(
163+
@Nullable @RequestParam(value = "format", required = false) final String format) {
164+
return service.embed("Hello, world!");
165+
}
160166
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
1010
import com.sap.ai.sdk.orchestration.DpiMasking;
1111
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
12+
import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingModel;
1213
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
1314
import com.sap.ai.sdk.orchestration.model.DPIEntities;
1415
import com.sap.ai.sdk.orchestration.spring.OrchestrationChatModel;
1516
import com.sap.ai.sdk.orchestration.spring.OrchestrationChatOptions;
17+
import com.sap.ai.sdk.orchestration.spring.OrchestrationSpringEmbeddingModel;
1618
import java.util.List;
1719
import java.util.Map;
1820
import java.util.Objects;
@@ -29,6 +31,7 @@
2931
import org.springframework.ai.chat.model.ChatResponse;
3032
import org.springframework.ai.chat.prompt.Prompt;
3133
import org.springframework.ai.chat.prompt.PromptTemplate;
34+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
3235
import org.springframework.ai.support.ToolCallbacks;
3336
import org.springframework.ai.tool.ToolCallbackProvider;
3437
import org.springframework.beans.factory.annotation.Autowired;
@@ -263,4 +266,19 @@ public Translation responseFormat() {
263266
"How do I say 'AI is going to revolutionize the world' in dutch?", defaultOptions);
264267
return cl.prompt(prompt).call().entity(Translation.class);
265268
}
269+
270+
/**
271+
* Create an embedding for a given text using the Orchestration service.
272+
*
273+
* @param inputText the text to embed
274+
* @return the embedding as a float array
275+
*/
276+
@Nonnull
277+
public float[] embed(@Nonnull final String inputText) {
278+
val embedOptions =
279+
EmbeddingOptionsBuilder.builder()
280+
.withModel(OrchestrationEmbeddingModel.TEXT_EMBEDDING_3_SMALL.name())
281+
.build();
282+
return new OrchestrationSpringEmbeddingModel(embedOptions).embed(inputText);
283+
}
266284
}

0 commit comments

Comments
 (0)