Skip to content

Commit 3a8a128

Browse files
committed
add test for call and use low level api
1 parent d4f663c commit 3a8a128

File tree

2 files changed

+84
-13
lines changed

2 files changed

+84
-13
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiSpringEmbeddingModel.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package com.sap.ai.sdk.foundationmodels.openai.spring;
22

33
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
4-
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest;
5-
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse;
64
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
5+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
8+
import java.util.Objects;
79
import java.util.stream.IntStream;
810
import javax.annotation.Nonnull;
11+
import lombok.NoArgsConstructor;
912
import org.springframework.ai.chat.metadata.DefaultUsage;
1013
import org.springframework.ai.document.Document;
1114
import org.springframework.ai.embedding.Embedding;
@@ -14,14 +17,17 @@
1417
import org.springframework.ai.embedding.EmbeddingResponse;
1518
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
1619

20+
@NoArgsConstructor
1721
public class OpenAiSpringEmbeddingModel implements EmbeddingModel {
1822

1923
@Override
2024
public EmbeddingResponse call(@Nonnull final EmbeddingRequest request) {
21-
22-
var modelName = request.getOptions().getModel();
25+
var modelName =
26+
Objects.requireNonNull(
27+
request.getOptions().getModel(), "Model name is required if client is not set");
2328
var client = OpenAiClient.forModel(new OpenAiModel(modelName, null));
24-
var openAiRequest = createOpenAiEmbeddingRequest(request);
29+
30+
var openAiRequest = createEmbeddingCreateRequest(request);
2531
var openAiResponse = client.embedding(openAiRequest);
2632

2733
return createSpringAiEmbeddingResponse(openAiResponse);
@@ -36,20 +42,22 @@ public float[] embed(@Nonnull final Document document) {
3642
"Only text type document supported. Metadata contains " + document.getMetadata());
3743
}
3844

39-
private OpenAiEmbeddingRequest createOpenAiEmbeddingRequest(
45+
private EmbeddingsCreateRequest createEmbeddingCreateRequest(
4046
@Nonnull final EmbeddingRequest request) {
41-
return new OpenAiEmbeddingRequest(request.getInstructions())
42-
.withDimensions(request.getOptions().getDimensions());
47+
return new EmbeddingsCreateRequest()
48+
.dimensions(request.getOptions().getDimensions())
49+
.input(EmbeddingsCreateRequestInput.create(request.getInstructions()));
4350
}
4451

45-
private EmbeddingResponse createSpringAiEmbeddingResponse(OpenAiEmbeddingResponse response) {
46-
var vectors = response.getEmbeddingVectors();
52+
private EmbeddingResponse createSpringAiEmbeddingResponse(EmbeddingsCreate200Response response) {
4753
var embeddings =
48-
IntStream.range(0, vectors.size()).mapToObj(i -> new Embedding(vectors.get(i), i)).toList();
54+
IntStream.range(0, response.getData().size())
55+
.mapToObj(i -> new Embedding(response.getData().get(i).getEmbedding(), i))
56+
.toList();
4957

50-
var openAiUsage = response.getOriginalResponse().getUsage();
58+
var openAiUsage = response.getUsage();
5159
var usage = new DefaultUsage(openAiUsage.getPromptTokens(), null, openAiUsage.getTotalTokens());
52-
var metadata = new EmbeddingResponseMetadata(response.getOriginalResponse().getModel(), usage);
60+
var metadata = new EmbeddingResponseMetadata(response.getModel(), usage);
5361

5462
return new EmbeddingResponse(embeddings, metadata);
5563
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package com.sap.ai.sdk.foundationmodels.openai.spring;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.mockito.ArgumentMatchers.any;
5+
import static org.mockito.Mockito.mock;
6+
import static org.mockito.Mockito.mockStatic;
7+
import static org.mockito.Mockito.when;
8+
9+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
10+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
11+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200ResponseDataInner;
12+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200ResponseUsage;
13+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
14+
import java.util.List;
15+
import org.junit.jupiter.api.BeforeEach;
16+
import org.junit.jupiter.api.Test;
17+
import org.springframework.ai.chat.metadata.DefaultUsage;
18+
import org.springframework.ai.embedding.Embedding;
19+
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
20+
import org.springframework.ai.embedding.EmbeddingRequest;
21+
import org.springframework.ai.embedding.EmbeddingResponse;
22+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
23+
24+
class EmbeddingModelTest {
25+
private OpenAiSpringEmbeddingModel model;
26+
private OpenAiClient client;
27+
28+
@BeforeEach
29+
void setUp() {
30+
client = mock(OpenAiClient.class);
31+
model = new OpenAiSpringEmbeddingModel();
32+
}
33+
34+
@Test
35+
void testCall() {
36+
var request =
37+
new EmbeddingRequest(
38+
List.of("instructions"),
39+
EmbeddingOptionsBuilder.builder().withModel("model").withDimensions(128).build());
40+
41+
var vector = new float[] {0.0f};
42+
var expectedResponse =
43+
new EmbeddingResponse(
44+
List.of(new Embedding(vector, 0)),
45+
new EmbeddingResponseMetadata("model", new DefaultUsage(0, null, 0)));
46+
47+
var openAiResponse =
48+
new EmbeddingsCreate200Response()
49+
.data(List.of(new EmbeddingsCreate200ResponseDataInner().embedding(vector)))
50+
.model("model")
51+
.usage(new EmbeddingsCreate200ResponseUsage().promptTokens(0).totalTokens(0));
52+
53+
when(client.embedding(any(EmbeddingsCreateRequest.class))).thenReturn(openAiResponse);
54+
55+
try (var mockedStatic = mockStatic(OpenAiClient.class)) {
56+
mockedStatic.when(() -> OpenAiClient.forModel(any())).thenReturn(client);
57+
58+
var actualResponse = model.call(request);
59+
60+
assertThat(expectedResponse).usingRecursiveComparison().isEqualTo(actualResponse);
61+
}
62+
}
63+
}

0 commit comments

Comments
 (0)