Skip to content

Commit deacbfb

Browse files
committed
OpenAI Embedding convenience response and request
1 parent 34ff288 commit deacbfb

File tree

7 files changed

+200
-25
lines changed

7 files changed

+200
-25
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,29 @@ private void warnIfUnsupportedUsage() {
341341
}
342342

343343
/**
344-
* Get a vector representation of a given request with input that can be easily consumed by
345-
* machine learning models and algorithms.
344+
* Get a vector representation of a given request that can be easily consumed by machine learning
345+
* models and algorithms using high-level request object.
346+
*
347+
* @param request the request with input text.
348+
* @return the embedding response convenience object
349+
* @throws OpenAiClientException if the request fails
350+
* @see #embedding(EmbeddingsCreateRequest) for full confgurability.
351+
* @since 1.4.0
352+
*/
353+
@Beta
354+
@Nonnull
355+
public OpenAiEmbeddingResponse embedding(@Nonnull final OpenAiEmbeddingRequest request)
356+
throws OpenAiClientException {
357+
return new OpenAiEmbeddingResponse(embedding(request.createEmbeddingsCreateRequest()));
358+
}
359+
360+
/**
361+
* Get a vector representation of a given inputs using low-level request.
346362
*
347363
* @param request the request with input text.
348364
* @return the embedding output
349365
* @throws OpenAiClientException if the request fails
366+
* @see #embedding(OpenAiEmbeddingRequest) for conveninece api
350367
* @since 1.4.0
351368
*/
352369
@Beta
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
5+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
6+
import java.util.Collections;
7+
import java.util.List;
8+
import javax.annotation.Nonnull;
9+
import lombok.Value;
10+
import lombok.With;
11+
12+
/**
13+
* Represents a request to create embeddings using OpenAI's API.
14+
*
15+
* <p>A high-level wrapper over the generated model class {@code EmbeddingsCreateRequest}, *
16+
* improving API usability for common use cases such as creation from a list of tokens.
17+
*
18+
* @since 1.4.0
19+
*/
20+
@Beta
21+
@Value
22+
@With
23+
public class OpenAiEmbeddingRequest {
24+
/** List of tokens to be embedded. */
25+
@Nonnull private final List<String> tokens;
26+
27+
/**
28+
* Constructs an OpenAiEmbeddingRequest from a basic string.
29+
*
30+
* @param token the token to be embedded
31+
*/
32+
public OpenAiEmbeddingRequest(@Nonnull final String token) {
33+
this(List.of(token));
34+
}
35+
36+
/**
37+
* Constructs an OpenAiEmbeddingRequest from a list of strings.
38+
*
39+
* @param tokens a list of tokens to be embedded
40+
*/
41+
public OpenAiEmbeddingRequest(@Nonnull final List<String> tokens) {
42+
this.tokens = Collections.unmodifiableList(tokens);
43+
}
44+
45+
/**
46+
* Converts this request to an EmbeddingsCreateRequest.
47+
*
48+
* @return an EmbeddingsCreateRequest with the tokens to be embedded
49+
*/
50+
@Nonnull
51+
EmbeddingsCreateRequest createEmbeddingsCreateRequest() {
52+
if (tokens.size() == 1) {
53+
return new EmbeddingsCreateRequest()
54+
.input(EmbeddingsCreateRequestInput.create(tokens.get(0)));
55+
}
56+
57+
return new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(tokens));
58+
}
59+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static lombok.AccessLevel.NONE;
4+
import static lombok.AccessLevel.PACKAGE;
5+
6+
import com.google.common.annotations.Beta;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import javax.annotation.Nonnull;
11+
import lombok.AllArgsConstructor;
12+
import lombok.Setter;
13+
import lombok.Value;
14+
15+
/**
16+
* Represents a response from the OpenAI Embedding API.
17+
*
18+
* <p>A high-level wrapper over the generated model class {@code EmbeddingsCreate200Response},
19+
* improving API usability for common use cases, such as extracting embeddings.
20+
*
21+
* @since 1.4.0
22+
*/
23+
@Beta
24+
@Value
25+
@AllArgsConstructor(access = PACKAGE)
26+
@Setter(value = NONE)
27+
public class OpenAiEmbeddingResponse {
28+
29+
/** The original response from the OpenAI Embedding API. */
30+
@Nonnull EmbeddingsCreate200Response originalResponse;
31+
32+
/**
33+
* Read the embeddings from the response as a list of float arrays.
34+
*
35+
* @return a list of float arrays
36+
*/
37+
@Nonnull
38+
public List<float[]> getEmbeddings() {
39+
final var embeddings = new ArrayList<float[]>();
40+
for (final var container : originalResponse.getData()) {
41+
42+
final var embeddingDecimals = container.getEmbedding();
43+
final var embeddingFloats = new float[embeddingDecimals.size()];
44+
45+
for (int i = 0; i < embeddingDecimals.size(); i++) {
46+
embeddingFloats[i] = embeddingDecimals.get(i).floatValue();
47+
}
48+
embeddings.add(embeddingFloats);
49+
}
50+
return embeddings;
51+
}
52+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
8+
import java.util.List;
9+
import lombok.SneakyThrows;
10+
import org.junit.jupiter.api.Test;
11+
12+
@WireMockTest
13+
class EmbeddingConvenienceTest {
14+
15+
@Test
16+
void createEmbeddingRequestWithSingleToken() {
17+
var request = new OpenAiEmbeddingRequest("token1");
18+
var lowLevelRequest = request.createEmbeddingsCreateRequest();
19+
20+
assertThat(((EmbeddingsCreateRequestInput.InnerString) lowLevelRequest.getInput()))
21+
.usingRecursiveComparison()
22+
.isEqualTo(EmbeddingsCreateRequestInput.create("token1"));
23+
}
24+
25+
@Test
26+
void createEmbeddingRequestWithMultipleTokens() {
27+
var request = new OpenAiEmbeddingRequest(List.of("token1", "token2", "token3"));
28+
var lowLevelRequest = request.createEmbeddingsCreateRequest();
29+
30+
assertThat(((EmbeddingsCreateRequestInput.InnerStrings) lowLevelRequest.getInput()).values())
31+
.containsExactly("token1", "token2", "token3");
32+
}
33+
34+
@SneakyThrows
35+
@Test
36+
void getEmbeddings() {
37+
var originalResponse =
38+
OpenAiUtils.getOpenAiObjectMapper()
39+
.readValue(
40+
getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
41+
EmbeddingsCreate200Response.class);
42+
43+
var response = new OpenAiEmbeddingResponse(originalResponse);
44+
45+
assertThat(response.getEmbeddings())
46+
.containsExactly(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
47+
}
48+
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterPromptResults;
3434
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
3535
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponseChoicesInner;
36-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
37-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
3836
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
3937
import com.sap.ai.sdk.foundationmodels.openai.generated.model.PromptFilterResult;
4038
import io.vavr.control.Try;
@@ -247,21 +245,19 @@ void history() {
247245
void embedding() {
248246
stubForEmbedding();
249247

250-
final var result =
251-
client.embedding(
252-
new EmbeddingsCreateRequest()
253-
.input(EmbeddingsCreateRequestInput.create("Hello World")));
248+
final var response =
249+
client.embedding(new OpenAiEmbeddingRequest("Hello World")).getOriginalResponse();
254250

255-
assertThat(result).isNotNull();
256-
assertThat(result.getModel()).isEqualTo("ada");
257-
assertThat(result.getObject()).isEqualTo("list");
251+
assertThat(response).isNotNull();
252+
assertThat(response.getModel()).isEqualTo("ada");
253+
assertThat(response.getObject()).isEqualTo("list");
258254

259-
assertThat(result.getUsage()).isNotNull();
260-
assertThat(result.getUsage().getPromptTokens()).isEqualTo(2);
261-
assertThat(result.getUsage().getTotalTokens()).isEqualTo(2);
255+
assertThat(response.getUsage()).isNotNull();
256+
assertThat(response.getUsage().getPromptTokens()).isEqualTo(2);
257+
assertThat(response.getUsage().getTotalTokens()).isEqualTo(2);
262258

263-
assertThat(result.getData()).isNotNull().hasSize(1);
264-
var embeddingData = result.getData().get(0);
259+
assertThat(response.getData()).isNotNull().hasSize(1);
260+
var embeddingData = response.getData().get(0);
265261
assertThat(embeddingData).isNotNull();
266262
assertThat(embeddingData.getObject()).isEqualTo("embedding");
267263
assertThat(embeddingData.getIndex()).isZero();

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
1010
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1111
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
12+
import java.util.ArrayList;
1213
import java.util.concurrent.atomic.AtomicInteger;
1314
import java.util.concurrent.atomic.AtomicReference;
1415
import lombok.extern.slf4j.Slf4j;
@@ -87,9 +88,13 @@ void chatCompletionTools() {
8788
void embedding() {
8889
final var embedding = service.embedding("Hello world");
8990

90-
assertThat(embedding.getData().get(0).getEmbedding()).hasSizeGreaterThan(1);
91-
assertThat(embedding.getModel()).isEqualTo("ada");
92-
assertThat(embedding.getObject()).isEqualTo("list");
91+
assertThat(embedding.getOriginalResponse().getData().get(0).getEmbedding())
92+
.hasSizeGreaterThan(1);
93+
assertThat(embedding.getEmbeddings()).isInstanceOf(ArrayList.class);
94+
assertThat(embedding.getEmbeddings().get(0)).isInstanceOf(float[].class);
95+
96+
assertThat(embedding.getOriginalResponse().getModel()).isEqualTo("ada");
97+
assertThat(embedding.getOriginalResponse().getObject()).isEqualTo("list");
9398
}
9499

95100
@Test

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
1515
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse;
1616
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
17+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest;
18+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse;
1719
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1820
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
1921
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
@@ -26,9 +28,6 @@
2628
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
2729
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
2830
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
29-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
30-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
31-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
3231
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
3332
import java.net.URI;
3433
import java.util.List;
@@ -152,9 +151,8 @@ public OpenAiChatCompletionResponse chatCompletionTools(final int months) {
152151
* @return the embedding response
153152
*/
154153
@Nonnull
155-
public EmbeddingsCreate200Response embedding(@Nonnull final String input) {
156-
final var request =
157-
new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input));
154+
public OpenAiEmbeddingResponse embedding(@Nonnull final String input) {
155+
final var request = new OpenAiEmbeddingRequest(input);
158156

159157
return OpenAiClient.forModel(TEXT_EMBEDDING_ADA_002).embedding(request);
160158
}

0 commit comments

Comments
 (0)