
Commit 8e94e2d

rpanackal, Jonas-Isr, and CharlesDuboisSAP authored
feat: [OpenAI] SpringAI integration for embedding calls (#375)
* An initial version for embedding request integration
  - Request convenience extended
  - client per request (reassess)
  - full metadata loading on request
  - full metadata loading on response (reassess)
* add test for `call` and use low level api
* Constructor and code quality
  - OpenAiSpringEmbeddingModel(client)
  - unit tests added
  - e2e test added
  - java docs
* Sample app controller and minor test update
  - make embedding document test response
  - add controller for spring open ai
  - include in static files of sample app
* Remove drive-by changes
* Provide documentation+release notes and add test for overlapping model setting
  - test improvements
* Fix merge
* Fix hard line wrap on html
* Minor test and release note changes
  - prefix with "test" on all test method
  - update link in release note
  - update test display name
* fix type
* Additional constructor that take metadata mode
* Additional constructor that take metadata mode
* Charles preferences
* Update link in static files
* read embedding response from json
* missed model name test
* Enforcer plugin config for optional dependency and replicate import test for spring-ai
  - sample app ignored from ban
* Fix enforcer config, ported import check changes to orchestration

---------

Co-authored-by: Roshin Rajan Panackal <[email protected]>
Co-authored-by: Jonas-Isr <[email protected]>
Co-authored-by: I538344 <[email protected]>
1 parent 62faaab commit 8e94e2d

13 files changed: +461 -18 lines changed

docs/guides/SPRING_AI_INTEGRATION.md

Lines changed: 38 additions & 11 deletions
@@ -3,11 +3,15 @@
 ## Table of Contents

 - [Introduction](#introduction)
-- [Orchestration Chat Completion](#orchestration-chat-completion)
-- [Orchestration Masking](#orchestration-masking)
-- [Stream chat completion](#stream-chat-completion)
-- [Tool Calling](#tool-calling)
-- [Chat Memory](#chat-memory)
+- [Orchestration](#orchestration)
+  - [Chat Completion](#chat-completion)
+  - [Masking](#masking)
+  - [Stream chat completion](#stream-chat-completion)
+  - [Tool Calling](#tool-calling)
+  - [Chat Memory](#chat-memory)
+- [OpenAI](#openai)
+  - [Embedding](#embedding)
+

 ## Introduction

@@ -36,10 +40,12 @@ First, add the Spring AI dependency to your `pom.xml`:
 > [!NOTE]
 > Note that currently no stable version of Spring AI exists just yet.
 > The AI SDK currently uses the [M6 milestone](https://spring.io/blog/2025/02/14/spring-ai-1-0-0-m6-released).
->
+>
 > Please be aware that future versions of the AI SDK may increase the Spring AI version.

-## Orchestration Chat Completion
+## Orchestration
+
+### Chat Completion

 The Orchestration client is integrated in Spring AI classes:

@@ -54,7 +60,7 @@ ChatResponse response = client.call(prompt);

 Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).

-## Orchestration Masking
+### Masking

 Configure Orchestration modules within Spring AI:

@@ -78,7 +84,7 @@ ChatResponse response = client.call(prompt);
 Please
 find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).

-## Stream chat completion
+### Stream chat completion

 It's possible to pass a stream of chat completion delta elements, e.g. from the application backend
 to the frontend in real-time.
@@ -94,15 +100,15 @@ Prompt prompt =
 Flux<ChatResponse> flux = client.stream(prompt);

 // also possible to keep only the chat completion text
-Flux<String> responseFlux =
+Flux<String> responseFlux =
     flux.map(chatResponse -> chatResponse.getResult().getOutput().getContent());
 ```

 _Note: A Spring endpoint can return `Flux` instead of `ResponseEntity`._

 Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).

-## Tool Calling
+### Tool Calling

 First define a function that will be called by the LLM:

@@ -161,3 +167,24 @@ String content2 = cl.prompt(prompt2).call().content();
 ```

 Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).
+
+## OpenAI
+
+### Introduction
+
+Our OpenAI client is integrated in Spring AI classes:
+
+### Embedding
+
+Here is how to obtain embedding vectors for a list of strings:
+
+You first initialize the OpenAI client for your model of choice and attach it to an `OpenAiSpringEmbeddingModel` object.
+
+```java
+OpenAiClient client = OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL);
+OpenAiSpringEmbeddingModel embeddingModel = new OpenAiSpringEmbeddingModel(client);
+List<String> texts = List.of("Hello", "World");
+List<float[]> embeddings = embeddingModel.embed(texts);
+```
+
+Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java).
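
The embedding vectors described in this guide section are plain `float[]` arrays, so they can be compared without further SDK calls. The following is a minimal sketch of one common comparison, cosine similarity. It is not part of the SDK or of this commit: `EmbeddingSimilarity` and `cosine` are hypothetical names chosen for illustration.

```java
/** Hypothetical helper for comparing embedding vectors; not part of the SDK. */
public final class EmbeddingSimilarity {

  private EmbeddingSimilarity() {}

  /**
   * Returns the cosine similarity of two vectors of equal length.
   *
   * @throws IllegalArgumentException if the lengths differ or a vector has zero magnitude
   */
  public static double cosine(final float[] a, final float[] b) {
    if (a.length != b.length) {
      throw new IllegalArgumentException("Vectors must have the same length.");
    }
    double dot = 0.0;
    double normA = 0.0;
    double normB = 0.0;
    for (int i = 0; i < a.length; i++) {
      // Widen to double before multiplying to limit rounding error.
      dot += (double) a[i] * b[i];
      normA += (double) a[i] * a[i];
      normB += (double) b[i] * b[i];
    }
    if (normA == 0.0 || normB == 0.0) {
      throw new IllegalArgumentException("Vectors must not have zero magnitude.");
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }
}
```

Passing the vectors obtained for two texts yields a value in the range [-1, 1]; values closer to 1 indicate that the vectors point in a more similar direction.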

docs/release-notes/release_notes.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 - [Orchestration] [Prompt templates can be consumed from registry.](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md#Chat-completion-with-Templates)
 - [Orchestration] [Masking is now available on grounding.](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md#mask-grounding)
 - [Orchestration] [Grounding via *help.sap.com* is enabled.](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md#grounding)
+- [OpenAI] [Spring AI integration for embedding calls.](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/SPRING_AI_INTEGRATION.md#embedding)

 ### 📈 Improvements

foundation-models/openai/pom.xml

Lines changed: 10 additions & 0 deletions
@@ -89,6 +89,11 @@
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.springframework.ai</groupId>
+      <artifactId>spring-ai-core</artifactId>
+      <optional>true</optional>
+    </dependency>
     <!-- scope "provided" -->
     <dependency>
       <groupId>org.projectlombok</groupId>
@@ -121,5 +126,10 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.github.javaparser</groupId>
+      <artifactId>javaparser-core</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 </project>
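
Because `spring-ai-core` is declared with `<optional>true</optional>`, Maven does not pass it on transitively, so an application that uses the Spring AI integration has to declare the dependency itself. The following is a minimal sketch of a runtime check for a class from such an optional dependency. It is not part of this commit: `OptionalDependencyCheck` and `isPresent` are hypothetical names used for illustration only.

```java
/** Hypothetical helper that checks whether a class is available on the classpath. */
public final class OptionalDependencyCheck {

  private OptionalDependencyCheck() {}

  /** Returns true if the named class can be loaded, false otherwise. */
  public static boolean isPresent(final String className) {
    try {
      // initialize=false: only locate the class, do not run its static initializers.
      Class.forName(className, false, OptionalDependencyCheck.class.getClassLoader());
      return true;
    } catch (ClassNotFoundException | LinkageError e) {
      return false;
    }
  }
}
```

A caller could, for example, evaluate `OptionalDependencyCheck.isPresent("org.springframework.ai.embedding.EmbeddingModel")` before touching any class that depends on Spring AI.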
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+package com.sap.ai.sdk.foundationmodels.openai.spring;
+
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
+import java.util.Objects;
+import java.util.stream.IntStream;
+import javax.annotation.Nonnull;
+import org.springframework.ai.chat.metadata.DefaultUsage;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+
+/**
+ * SpringAI integration with {@link OpenAiClient} to generate embeddings.
+ *
+ * <p>This model transforms an {@link EmbeddingRequest} into an Azure OpenAI request, processes the
+ * response, and returns a Spring AI {@link EmbeddingResponse}.
+ *
+ * @since 1.5.0
+ */
+public class OpenAiSpringEmbeddingModel implements EmbeddingModel {
+
+  private final OpenAiClient client;
+  private final MetadataMode metadataMode;
+
+  /**
+   * Constructs an {@code OpenAiSpringEmbeddingModel} with the specified {@link OpenAiClient} of
+   * some model.
+   *
+   * @param client the OpenAI client
+   */
+  public OpenAiSpringEmbeddingModel(@Nonnull final OpenAiClient client) {
+    this(client, MetadataMode.EMBED);
+  }
+
+  /**
+   * Constructs an {@code OpenAiSpringEmbeddingModel} with the specified {@link OpenAiClient} of
+   * some model and metadata mode.
+   *
+   * <p>The metadata mode is used by content formatter to determine which metadata to include in the
+   * resulting content. Currently, the formatter is only effective for calls to {@link
+   * #embed(Document)}.
+   *
+   * @param client the OpenAI client
+   * @param metadataMode the metadata mode
+   */
+  public OpenAiSpringEmbeddingModel(
+      @Nonnull final OpenAiClient client, @Nonnull final MetadataMode metadataMode) {
+    this.client = client;
+    this.metadataMode = metadataMode;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @throws IllegalArgumentException if {@code request.getOptions().getModel()} is not null.
+   */
+  @Override
+  @Nonnull
+  public EmbeddingResponse call(@Nonnull final EmbeddingRequest request)
+      throws IllegalArgumentException {
+
+    if (request.getOptions().getModel() != null) {
+      throw new IllegalArgumentException(
+          "Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
+    }
+
+    final var openAiRequest = createEmbeddingsCreateRequest(request);
+    final var openAiResponse = client.embedding(openAiRequest);
+
+    return createSpringAiEmbeddingResponse(openAiResponse);
+  }
+
+  @Override
+  @Nonnull
+  public float[] embed(@Nonnull final Document document) throws UnsupportedOperationException {
+    return embed(
+        Objects.requireNonNull(
+            document.getFormattedContent(this.metadataMode),
+            "Formatted content of the document should not be null."));
+  }
+
+  private EmbeddingsCreateRequest createEmbeddingsCreateRequest(
+      @Nonnull final EmbeddingRequest request) {
+    return new EmbeddingsCreateRequest()
+        .dimensions(request.getOptions().getDimensions())
+        .input(EmbeddingsCreateRequestInput.create(request.getInstructions()));
+  }
+
+  private EmbeddingResponse createSpringAiEmbeddingResponse(
+      @Nonnull final EmbeddingsCreate200Response response) {
+    final var embeddings =
+        IntStream.range(0, response.getData().size())
+            .mapToObj(i -> new Embedding(response.getData().get(i).getEmbedding(), i))
+            .toList();
+
+    final var openAiUsage = response.getUsage();
+    final var usage =
+        new DefaultUsage(openAiUsage.getPromptTokens(), null, openAiUsage.getTotalTokens());
+    final var metadata = new EmbeddingResponseMetadata(response.getModel(), usage);
+
+    return new EmbeddingResponse(embeddings, metadata);
+  }
+}
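
The response mapping in `createSpringAiEmbeddingResponse` pairs every returned vector with its position in the response data. The following is a simplified sketch of just that step, using stand-in types so that it compiles without the SDK or Spring AI. The records `RawItem` and `IndexedEmbedding` are hypothetical placeholders for the generated OpenAI response item and Spring AI's `Embedding`; they are not real classes from either library.

```java
import java.util.List;
import java.util.stream.IntStream;

/** Sketch of index-preserving mapping with hypothetical stand-in types. */
public final class EmbeddingIndexMapping {

  private EmbeddingIndexMapping() {}

  /** Stand-in for one item of the raw response data. */
  public record RawItem(float[] vector) {}

  /** Stand-in for an embedding that remembers its position in the response. */
  public record IndexedEmbedding(float[] vector, int index) {}

  /** Maps each raw item to an indexed embedding, keeping the original order. */
  public static List<IndexedEmbedding> toIndexed(final List<RawItem> data) {
    return IntStream.range(0, data.size())
        .mapToObj(i -> new IndexedEmbedding(data.get(i).vector(), i))
        .toList();
  }
}
```

Iterating over indices instead of over the elements is what makes the position available to the constructor of the result type.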
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+package com.sap.ai.sdk.foundationmodels.openai.spring;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.assertArg;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
+import java.util.List;
+import java.util.function.Consumer;
+import lombok.SneakyThrows;
+import lombok.val;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
+import org.springframework.ai.embedding.EmbeddingRequest;
+
+class EmbeddingModelTest {
+  private OpenAiClient client;
+
+  @BeforeEach
+  void setUp() {
+    client = mock(OpenAiClient.class);
+  }
+
+  @SneakyThrows
+  @Test
+  @DisplayName("Call with embedding request containing valid options")
+  void testCallWithValidEmbeddingRequest() {
+    val texts = List.of("Some text");
+    val springAiRequest =
+        new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().withDimensions(128).build());
+
+    val expectedOpenAiResponse =
+        new ObjectMapper()
+            .readValue(
+                getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
+                EmbeddingsCreate200Response.class);
+
+    val expectedOpenAiRequest =
+        new EmbeddingsCreateRequest()
+            .input(EmbeddingsCreateRequestInput.create(texts))
+            .dimensions(128);
+
+    when(client.embedding(assertArg(assertRecursiveEquals(expectedOpenAiRequest))))
+        .thenReturn(expectedOpenAiResponse);
+
+    val actualSpringAiResponse = new OpenAiSpringEmbeddingModel(client).call(springAiRequest);
+
+    assertThat(actualSpringAiResponse).isNotNull();
+    assertThat(actualSpringAiResponse.getResult().getOutput())
+        .isEqualTo(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
+    assertThat(actualSpringAiResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(2);
+    assertThat(actualSpringAiResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(2);
+    assertThat(actualSpringAiResponse.getMetadata().getModel()).isEqualTo("ada");
+  }
+
+  @Test
+  @DisplayName("Call with embedding request with model option set throws exception")
+  void testCallWithModelOptionSetThrows() {
+    val springAiRequest =
+        new EmbeddingRequest(
+            List.of("Some text"), EmbeddingOptionsBuilder.builder().withModel("model").build());
+
+    val model = new OpenAiSpringEmbeddingModel(client);
+
+    assertThatThrownBy(() -> model.call(springAiRequest))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage(
+            "Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
+  }
+
+  @SneakyThrows
+  @Test
+  @DisplayName("Embed document with text content")
+  void testEmbedDocument() {
+    Document document = new Document("Some content");
+
+    val expectedOpenAiResponse =
+        new ObjectMapper()
+            .readValue(
+                getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
+                EmbeddingsCreate200Response.class);
+
+    val expectedOpenAiRequest =
+        new EmbeddingsCreateRequest()
+            .input(EmbeddingsCreateRequestInput.create(List.of(document.getFormattedContent())));
+
+    when(client.embedding(assertArg(assertRecursiveEquals(expectedOpenAiRequest))))
+        .thenReturn(expectedOpenAiResponse);
+
+    float[] result = new OpenAiSpringEmbeddingModel(client).embed(document);
+
+    assertThat(result).isEqualTo(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
+  }
+
+  private static <T> Consumer<T> assertRecursiveEquals(T expected) {
+    return (actual) -> {
+      assertThat(actual).usingRecursiveComparison().isEqualTo(expected);
+    };
+  }
+}
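
The test `testCallWithModelOptionSetThrows` pins down the guard in `call`: a request that names a model is rejected, because the client already determines the model. The following is a dependency-free sketch of that guard. The `Options` record is a hypothetical stand-in for the Spring AI embedding options, reduced to the one field that matters here, and `ModelGuard` is an illustrative name that does not exist in the SDK.

```java
/** Sketch of the model guard with a hypothetical stand-in for the request options. */
public final class ModelGuard {

  private ModelGuard() {}

  /** Stand-in for the request options; only the model name is modelled. */
  public record Options(String model) {}

  /**
   * Rejects options that carry a model name.
   *
   * @throws IllegalArgumentException if a model is set
   */
  public static void requireNoModel(final Options options) {
    if (options.model() != null) {
      throw new IllegalArgumentException(
          "Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
    }
  }
}
```

Failing fast like this avoids silently ignoring a model name that the caller may have expected to take effect.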
