Skip to content

Commit b666f26

Browse files
committed
Constructor and code quality
- OpenAiSpringEmbeddingModel(client) - unit tests added - e2e test added - java docs
1 parent 3a8a128 commit b666f26

File tree

4 files changed

+114
-41
lines changed

4 files changed

+114
-41
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiSpringEmbeddingModel.java

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
package com.sap.ai.sdk.foundationmodels.openai.spring;
22

3+
import com.google.common.annotations.Beta;
34
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
4-
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
55
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
66
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
77
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
88
import java.util.Objects;
99
import java.util.stream.IntStream;
1010
import javax.annotation.Nonnull;
11-
import lombok.NoArgsConstructor;
1211
import org.springframework.ai.chat.metadata.DefaultUsage;
1312
import org.springframework.ai.document.Document;
1413
import org.springframework.ai.embedding.Embedding;
@@ -17,29 +16,45 @@
1716
import org.springframework.ai.embedding.EmbeddingResponse;
1817
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
1918

20-
@NoArgsConstructor
19+
/**
20+
* SpringAI integration with {@link OpenAiClient} to generate embeddings.
21+
*
22+
* <p>This model transforms an {@link EmbeddingRequest} into an Azure OpenAI request, processes the
23+
* response, and returns a Spring AI {@link EmbeddingResponse}.
24+
*
25+
* @since 1.5.0
26+
*/
27+
@Beta
2128
public class OpenAiSpringEmbeddingModel implements EmbeddingModel {
2229

30+
private final OpenAiClient client;
31+
32+
/**
33+
* Constructs an {@code OpenAiSpringEmbeddingModel} with the specified {@link OpenAiClient} of
34+
* some model.
35+
*
36+
* @param client the OpenAI client
37+
*/
38+
public OpenAiSpringEmbeddingModel(@Nonnull final OpenAiClient client) {
39+
this.client = client;
40+
}
41+
2342
@Override
43+
@Nonnull
2444
public EmbeddingResponse call(@Nonnull final EmbeddingRequest request) {
25-
var modelName =
26-
Objects.requireNonNull(
27-
request.getOptions().getModel(), "Model name is required if client is not set");
28-
var client = OpenAiClient.forModel(new OpenAiModel(modelName, null));
29-
30-
var openAiRequest = createEmbeddingCreateRequest(request);
31-
var openAiResponse = client.embedding(openAiRequest);
45+
final var openAiRequest = createEmbeddingCreateRequest(request);
46+
final var openAiResponse = client.embedding(openAiRequest);
3247

3348
return createSpringAiEmbeddingResponse(openAiResponse);
3449
}
3550

3651
@Override
37-
public float[] embed(@Nonnull final Document document) {
52+
@Nonnull
53+
public float[] embed(@Nonnull final Document document) throws UnsupportedOperationException {
3854
if (document.isText()) {
39-
return embed(document.getText());
55+
return embed(Objects.requireNonNull(document.getText(), "Document text is null"));
4056
}
41-
throw new UnsupportedOperationException(
42-
"Only text type document supported. Metadata contains " + document.getMetadata());
57+
throw new UnsupportedOperationException("Only text type document supported for embedding");
4358
}
4459

4560
private EmbeddingsCreateRequest createEmbeddingCreateRequest(
@@ -49,15 +64,17 @@ private EmbeddingsCreateRequest createEmbeddingCreateRequest(
4964
.input(EmbeddingsCreateRequestInput.create(request.getInstructions()));
5065
}
5166

52-
private EmbeddingResponse createSpringAiEmbeddingResponse(EmbeddingsCreate200Response response) {
53-
var embeddings =
67+
private EmbeddingResponse createSpringAiEmbeddingResponse(
68+
@Nonnull final EmbeddingsCreate200Response response) {
69+
final var embeddings =
5470
IntStream.range(0, response.getData().size())
5571
.mapToObj(i -> new Embedding(response.getData().get(i).getEmbedding(), i))
5672
.toList();
5773

58-
var openAiUsage = response.getUsage();
59-
var usage = new DefaultUsage(openAiUsage.getPromptTokens(), null, openAiUsage.getTotalTokens());
60-
var metadata = new EmbeddingResponseMetadata(response.getModel(), usage);
74+
final var openAiUsage = response.getUsage();
75+
final var usage =
76+
new DefaultUsage(openAiUsage.getPromptTokens(), null, openAiUsage.getTotalTokens());
77+
final var metadata = new EmbeddingResponseMetadata(response.getModel(), usage);
6178

6279
return new EmbeddingResponse(embeddings, metadata);
6380
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/EmbeddingModelTest.java

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package com.sap.ai.sdk.foundationmodels.openai.spring;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
46
import static org.mockito.ArgumentMatchers.any;
57
import static org.mockito.Mockito.mock;
6-
import static org.mockito.Mockito.mockStatic;
78
import static org.mockito.Mockito.when;
89

910
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
@@ -15,24 +16,23 @@
1516
import org.junit.jupiter.api.BeforeEach;
1617
import org.junit.jupiter.api.Test;
1718
import org.springframework.ai.chat.metadata.DefaultUsage;
19+
import org.springframework.ai.document.Document;
1820
import org.springframework.ai.embedding.Embedding;
1921
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
2022
import org.springframework.ai.embedding.EmbeddingRequest;
2123
import org.springframework.ai.embedding.EmbeddingResponse;
2224
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
2325

2426
class EmbeddingModelTest {
25-
private OpenAiSpringEmbeddingModel model;
2627
private OpenAiClient client;
2728

2829
@BeforeEach
2930
void setUp() {
3031
client = mock(OpenAiClient.class);
31-
model = new OpenAiSpringEmbeddingModel();
3232
}
3333

3434
@Test
35-
void testCall() {
35+
void callWithValidRequest() {
3636
var request =
3737
new EmbeddingRequest(
3838
List.of("instructions"),
@@ -52,12 +52,37 @@ void testCall() {
5252

5353
when(client.embedding(any(EmbeddingsCreateRequest.class))).thenReturn(openAiResponse);
5454

55-
try (var mockedStatic = mockStatic(OpenAiClient.class)) {
56-
mockedStatic.when(() -> OpenAiClient.forModel(any())).thenReturn(client);
55+
var actualResponse = new OpenAiSpringEmbeddingModel(client).call(request);
5756

58-
var actualResponse = model.call(request);
57+
assertThat(expectedResponse).usingRecursiveComparison().isEqualTo(actualResponse);
58+
}
59+
60+
@Test
61+
void embedDocumentInvokesDefaultMethod() {
62+
Document document = new Document("Some content");
63+
64+
OpenAiSpringEmbeddingModel model =
65+
new OpenAiSpringEmbeddingModel(client) {
66+
@Override
67+
public float[] embed(String text) {
68+
// For testing, just return any array
69+
return new float[] {1, 2, 3};
70+
}
71+
};
72+
73+
float[] result = model.embed(document);
74+
assertArrayEquals(new float[] {1, 2, 3}, result);
75+
}
76+
77+
@Test
78+
void embedDocumentThrowsException() {
79+
var document = mock(Document.class);
80+
when(document.isText()).thenReturn(false);
81+
82+
var model = new OpenAiSpringEmbeddingModel(client);
5983

60-
assertThat(expectedResponse).usingRecursiveComparison().isEqualTo(actualResponse);
61-
}
84+
assertThatThrownBy(() -> model.embed(document))
85+
.isInstanceOf(UnsupportedOperationException.class)
86+
.hasMessage("Only text type document supported for embedding");
6287
}
6388
}
Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,45 @@
11
package com.sap.ai.sdk.app.services;
22

3+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
4+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
35
import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiSpringEmbeddingModel;
46
import java.util.List;
7+
import javax.annotation.Nonnull;
8+
import org.springframework.ai.document.Document;
59
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
610
import org.springframework.ai.embedding.EmbeddingRequest;
711
import org.springframework.ai.embedding.EmbeddingResponse;
812
import org.springframework.stereotype.Service;
913

14+
/** Service class for Spring AI integration with OpenAI */
1015
@Service
1116
public class SpringAiOpenAiService {
12-
public EmbeddingResponse embeddingWithStrings(String modelName) {
1317

14-
var options =
15-
EmbeddingOptionsBuilder.builder().withModel(modelName).withDimensions(128).build();
18+
private final OpenAiClient client = OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL);
1619

17-
var springAiRequest =
18-
new EmbeddingRequest(
19-
List.of(
20-
"The quick brown fox jumps over the lazy dog.",
21-
"To be or not to be, that is the question."),
22-
options);
23-
return new OpenAiSpringEmbeddingModel().call(springAiRequest);
20+
/**
21+
* Embeds a list of strings using the OpenAI embedding model.
22+
*
23+
* @param strings the list of strings to embed
24+
* @return an {@code EmbeddingResponse} containing the embeddings and metadata
25+
*/
26+
@Nonnull
27+
public EmbeddingResponse embedWithEmbeddingRequest(@Nonnull final List<String> strings) {
28+
final var options = EmbeddingOptionsBuilder.builder().withDimensions(128).build();
29+
final var springAiRequest = new EmbeddingRequest(strings, options);
30+
31+
return new OpenAiSpringEmbeddingModel(client).call(springAiRequest);
32+
}
33+
34+
/**
35+
* Embeds the content of a document using the OpenAI embedding model.
36+
*
37+
* @param content the content of the document to embed
38+
* @return a float array representing the embedding of the document's content
39+
*/
40+
@Nonnull
41+
public float[] embedWithDocument(@Nonnull final String content) {
42+
final var document = new Document(content);
43+
return new OpenAiSpringEmbeddingModel(client).embed(document);
2444
}
2545
}

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,29 @@
33
import static org.assertj.core.api.Assertions.assertThat;
44

55
import com.sap.ai.sdk.app.services.SpringAiOpenAiService;
6+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
7+
import java.util.List;
68
import org.junit.jupiter.api.Test;
79

810
class SpringAiOpenAiTest {
911

10-
private SpringAiOpenAiService service = new SpringAiOpenAiService();
12+
private final SpringAiOpenAiService service = new SpringAiOpenAiService();
1113

1214
@Test
13-
void testEmbeddingWithStrings() {
15+
void testEmbedWithEmbeddingRequest() {
1416

15-
var response = service.embeddingWithStrings("text-embedding-3-small");
17+
var response =
18+
service.embedWithEmbeddingRequest(
19+
List.of(
20+
"The quick brown fox jumps over the lazy dog.",
21+
"To be or not to be, that is the question."));
1622

1723
assertThat(response).isNotNull();
1824
assertThat(response.getResults()).hasSize(2);
25+
assertThat(response.getResults().get(0).getOutput()).hasSize(128);
26+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isNotNull();
27+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isNotNull();
28+
assertThat(response.getMetadata().getModel())
29+
.isEqualTo(OpenAiModel.TEXT_EMBEDDING_3_SMALL.name());
1930
}
2031
}

0 commit comments

Comments
 (0)