Skip to content

Commit e909edc

Browse files
committed
read embedding response from json
1 parent e7f14bc commit e909edc

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/EmbeddingModelTest.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,21 @@
66
import static org.mockito.Mockito.mock;
77
import static org.mockito.Mockito.when;
88

9+
import com.fasterxml.jackson.databind.ObjectMapper;
910
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
1011
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
11-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200ResponseDataInner;
12-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200ResponseUsage;
1312
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
1413
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
1514
import java.util.List;
1615
import java.util.function.Consumer;
16+
import lombok.SneakyThrows;
1717
import lombok.val;
1818
import org.junit.jupiter.api.BeforeEach;
1919
import org.junit.jupiter.api.DisplayName;
2020
import org.junit.jupiter.api.Test;
21-
import org.springframework.ai.chat.metadata.DefaultUsage;
2221
import org.springframework.ai.document.Document;
23-
import org.springframework.ai.embedding.Embedding;
2422
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
2523
import org.springframework.ai.embedding.EmbeddingRequest;
26-
import org.springframework.ai.embedding.EmbeddingResponse;
27-
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
2824

2925
class EmbeddingModelTest {
3026
private OpenAiClient client;
@@ -34,18 +30,19 @@ void setUp() {
3430
client = mock(OpenAiClient.class);
3531
}
3632

33+
@SneakyThrows
3734
@Test
3835
@DisplayName("Call with embedding request containing valid options")
3936
void testCallWithValidEmbeddingRequest() {
4037
val texts = List.of("Some text");
4138
val springAiRequest =
4239
new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().withDimensions(128).build());
4340

44-
val vector = new float[] {0.0f};
4541
val expectedOpenAiResponse =
46-
new EmbeddingsCreate200Response()
47-
.data(List.of(new EmbeddingsCreate200ResponseDataInner().embedding(vector)))
48-
.usage(new EmbeddingsCreate200ResponseUsage().promptTokens(0).totalTokens(0));
42+
new ObjectMapper()
43+
.readValue(
44+
getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
45+
EmbeddingsCreate200Response.class);
4946

5047
val expectedOpenAiRequest =
5148
new EmbeddingsCreateRequest()
@@ -57,15 +54,11 @@ void testCallWithValidEmbeddingRequest() {
5754

5855
val actualSpringAiResponse = new OpenAiSpringEmbeddingModel(client).call(springAiRequest);
5956

60-
val modelName = ""; // defined by client object and options not honoured
61-
val expectedSpringAiResponse =
62-
new EmbeddingResponse(
63-
List.of(new Embedding(vector, 0)),
64-
new EmbeddingResponseMetadata(modelName, new DefaultUsage(0, null, 0)));
65-
66-
assertThat(expectedSpringAiResponse)
67-
.usingRecursiveComparison()
68-
.isEqualTo(actualSpringAiResponse);
57+
assertThat(actualSpringAiResponse).isNotNull();
58+
assertThat(actualSpringAiResponse.getResult().getOutput())
59+
.isEqualTo(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
60+
assertThat(actualSpringAiResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(2);
61+
assertThat(actualSpringAiResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(2);
6962
}
7063

7164
@Test
@@ -83,27 +76,28 @@ void testCallWithModelOptionSetThrows() {
8376
"Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model.");
8477
}
8578

79+
@SneakyThrows
8680
@Test
8781
@DisplayName("Embed document with text content")
8882
void testEmbedDocument() {
8983
Document document = new Document("Some content");
9084

91-
val vector = new float[] {1, 2, 3};
92-
val openAiResponse =
93-
new EmbeddingsCreate200Response()
94-
.data(List.of(new EmbeddingsCreate200ResponseDataInner().embedding(vector)))
95-
.usage(new EmbeddingsCreate200ResponseUsage().promptTokens(0).totalTokens(0));
85+
val expectedOpenAiResponse =
86+
new ObjectMapper()
87+
.readValue(
88+
getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
89+
EmbeddingsCreate200Response.class);
9690

9791
val expectedOpenAiRequest =
9892
new EmbeddingsCreateRequest()
9993
.input(EmbeddingsCreateRequestInput.create(List.of(document.getFormattedContent())));
10094

10195
when(client.embedding(assertArg(assertRecursiveEquals(expectedOpenAiRequest))))
102-
.thenReturn(openAiResponse);
96+
.thenReturn(expectedOpenAiResponse);
10397

10498
float[] result = new OpenAiSpringEmbeddingModel(client).embed(document);
10599

106-
assertThat(result).isEqualTo(new float[] {1, 2, 3});
100+
assertThat(result).isEqualTo(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
107101
}
108102

109103
private static <T> Consumer<T> assertRecursiveEquals(T expected) {

0 commit comments

Comments
 (0)