Skip to content

Commit 9de1c1b

Browse files
authored
feat: [OpenAI] Embedding Request Convenience (#331)
* OpenAI Embedding convenience response and request * refactor imports * change method name and remove `@With` * Remove single token constructor in embedding request * Remove single token constructor in embedding request --------- Co-authored-by: Roshin Rajan Panackal <[email protected]>
1 parent 3c8e4f6 commit 9de1c1b

File tree

7 files changed

+181
-25
lines changed

7 files changed

+181
-25
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,29 @@ private void warnIfUnsupportedUsage() {
341341
}
342342

343343
/**
344-
* Get a vector representation of a given request with input that can be easily consumed by
345-
* machine learning models and algorithms.
344+
* Get a vector representation of a given request that can be easily consumed by machine learning
345+
* models and algorithms using high-level request object.
346+
*
347+
* @param request the request with input text.
348+
* @return the embedding response convenience object
349+
* @throws OpenAiClientException if the request fails
350+
* @see #embedding(EmbeddingsCreateRequest) for full confgurability.
351+
* @since 1.4.0
352+
*/
353+
@Beta
354+
@Nonnull
355+
public OpenAiEmbeddingResponse embedding(@Nonnull final OpenAiEmbeddingRequest request)
356+
throws OpenAiClientException {
357+
return new OpenAiEmbeddingResponse(embedding(request.createEmbeddingsCreateRequest()));
358+
}
359+
360+
/**
361+
* Get a vector representation of a given inputs using low-level request.
346362
*
347363
* @param request the request with input text.
348364
* @return the embedding output
349365
* @throws OpenAiClientException if the request fails
366+
* @see #embedding(OpenAiEmbeddingRequest) for conveninece api
350367
* @since 1.4.0
351368
*/
352369
@Beta
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
5+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
6+
import java.util.Collections;
7+
import java.util.List;
8+
import javax.annotation.Nonnull;
9+
import lombok.Value;
10+
11+
/**
12+
* Represents a request to create embeddings using OpenAI's API.
13+
*
14+
* <p>A high-level wrapper over the generated model class {@code EmbeddingsCreateRequest}, *
15+
* improving API usability for common use cases such as creation from a list of tokens.
16+
*
17+
* @since 1.4.0
18+
*/
19+
@Beta
20+
@Value
21+
public class OpenAiEmbeddingRequest {
22+
/** List of tokens to be embedded. */
23+
@Nonnull private final List<String> tokens;
24+
25+
/**
26+
* Constructs an OpenAiEmbeddingRequest from a list of strings.
27+
*
28+
* @param tokens a list of tokens to be embedded
29+
*/
30+
public OpenAiEmbeddingRequest(@Nonnull final List<String> tokens) {
31+
this.tokens = Collections.unmodifiableList(tokens);
32+
}
33+
34+
/**
35+
* Converts this request to an EmbeddingsCreateRequest.
36+
*
37+
* @return an EmbeddingsCreateRequest with the tokens to be embedded
38+
*/
39+
@Nonnull
40+
EmbeddingsCreateRequest createEmbeddingsCreateRequest() {
41+
if (tokens.size() == 1) {
42+
return new EmbeddingsCreateRequest()
43+
.input(EmbeddingsCreateRequestInput.create(tokens.get(0)));
44+
}
45+
46+
return new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(tokens));
47+
}
48+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static lombok.AccessLevel.NONE;
4+
import static lombok.AccessLevel.PACKAGE;
5+
6+
import com.google.common.annotations.Beta;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import javax.annotation.Nonnull;
11+
import lombok.AllArgsConstructor;
12+
import lombok.Setter;
13+
import lombok.Value;
14+
15+
/**
16+
* Represents a response from the OpenAI Embedding API.
17+
*
18+
* <p>A high-level wrapper over the generated model class {@code EmbeddingsCreate200Response},
19+
* improving API usability for common use cases, such as extracting embeddings.
20+
*
21+
* @since 1.4.0
22+
*/
23+
@Beta
24+
@Value
25+
@AllArgsConstructor(access = PACKAGE)
26+
@Setter(value = NONE)
27+
public class OpenAiEmbeddingResponse {
28+
29+
/** The original response from the OpenAI Embedding API. */
30+
@Nonnull EmbeddingsCreate200Response originalResponse;
31+
32+
/**
33+
* Read the embeddings from the response as a list of float arrays.
34+
*
35+
* @return a list of float arrays
36+
*/
37+
@Nonnull
38+
public List<float[]> getEmbeddingVectors() {
39+
final var embeddings = new ArrayList<float[]>();
40+
for (final var container : originalResponse.getData()) {
41+
42+
final var embeddingDecimals = container.getEmbedding();
43+
final var embeddingFloats = new float[embeddingDecimals.size()];
44+
45+
for (int i = 0; i < embeddingDecimals.size(); i++) {
46+
embeddingFloats[i] = embeddingDecimals.get(i).floatValue();
47+
}
48+
embeddings.add(embeddingFloats);
49+
}
50+
return embeddings;
51+
}
52+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
8+
import java.util.List;
9+
import lombok.SneakyThrows;
10+
import org.junit.jupiter.api.Test;
11+
12+
@WireMockTest
13+
class EmbeddingConvenienceTest {
14+
15+
@Test
16+
void createEmbeddingRequestWithMultipleTokens() {
17+
var request = new OpenAiEmbeddingRequest(List.of("token1", "token2", "token3"));
18+
var lowLevelRequest = request.createEmbeddingsCreateRequest();
19+
20+
assertThat(((EmbeddingsCreateRequestInput.InnerStrings) lowLevelRequest.getInput()).values())
21+
.containsExactly("token1", "token2", "token3");
22+
}
23+
24+
@SneakyThrows
25+
@Test
26+
void getEmbeddings() {
27+
var originalResponse =
28+
OpenAiUtils.getOpenAiObjectMapper()
29+
.readValue(
30+
getClass().getClassLoader().getResource("__files/embeddingResponse.json"),
31+
EmbeddingsCreate200Response.class);
32+
33+
var embeddings = new OpenAiEmbeddingResponse(originalResponse).getEmbeddingVectors();
34+
35+
assertThat(embeddings).isInstanceOf(List.class);
36+
assertThat(embeddings).hasSize(1);
37+
assertThat(embeddings)
38+
.containsExactly(new float[] {0.0f, 3.4028235E38f, 1.4E-45f, 1.23f, -4.56f});
39+
}
40+
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterPromptResults;
3131
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
3232
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponseChoicesInner;
33-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
34-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
3533
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
3634
import com.sap.ai.sdk.foundationmodels.openai.generated.model.PromptFilterResult;
3735
import io.vavr.control.Try;
@@ -244,21 +242,19 @@ void history() {
244242
void embedding() {
245243
stubForEmbedding();
246244

247-
final var result =
248-
client.embedding(
249-
new EmbeddingsCreateRequest()
250-
.input(EmbeddingsCreateRequestInput.create("Hello World")));
245+
final var response =
246+
client.embedding(new OpenAiEmbeddingRequest(List.of("Hello World"))).getOriginalResponse();
251247

252-
assertThat(result).isNotNull();
253-
assertThat(result.getModel()).isEqualTo("ada");
254-
assertThat(result.getObject()).isEqualTo("list");
248+
assertThat(response).isNotNull();
249+
assertThat(response.getModel()).isEqualTo("ada");
250+
assertThat(response.getObject()).isEqualTo("list");
255251

256-
assertThat(result.getUsage()).isNotNull();
257-
assertThat(result.getUsage().getPromptTokens()).isEqualTo(2);
258-
assertThat(result.getUsage().getTotalTokens()).isEqualTo(2);
252+
assertThat(response.getUsage()).isNotNull();
253+
assertThat(response.getUsage().getPromptTokens()).isEqualTo(2);
254+
assertThat(response.getUsage().getTotalTokens()).isEqualTo(2);
259255

260-
assertThat(result.getData()).isNotNull().hasSize(1);
261-
var embeddingData = result.getData().get(0);
256+
assertThat(response.getData()).isNotNull().hasSize(1);
257+
var embeddingData = response.getData().get(0);
262258
assertThat(embeddingData).isNotNull();
263259
assertThat(embeddingData.getObject()).isEqualTo("embedding");
264260
assertThat(embeddingData.getIndex()).isZero();

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
1010
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1111
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
12+
import java.util.ArrayList;
1213
import java.util.concurrent.atomic.AtomicInteger;
1314
import java.util.concurrent.atomic.AtomicReference;
1415
import lombok.extern.slf4j.Slf4j;
@@ -86,9 +87,13 @@ void chatCompletionTools() {
8687
void embedding() {
8788
final var embedding = service.embedding("Hello world");
8889

89-
assertThat(embedding.getData().get(0).getEmbedding()).hasSizeGreaterThan(1);
90-
assertThat(embedding.getModel()).isEqualTo("ada");
91-
assertThat(embedding.getObject()).isEqualTo("list");
90+
assertThat(embedding.getOriginalResponse().getData().get(0).getEmbedding())
91+
.hasSizeGreaterThan(1);
92+
assertThat(embedding.getEmbeddingVectors()).isInstanceOf(ArrayList.class);
93+
assertThat(embedding.getEmbeddingVectors().get(0)).isInstanceOf(float[].class);
94+
95+
assertThat(embedding.getOriginalResponse().getModel()).isEqualTo("ada");
96+
assertThat(embedding.getOriginalResponse().getObject()).isEqualTo("list");
9297
}
9398

9499
@Test

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
1111
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse;
1212
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
13+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest;
14+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse;
1315
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1416
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1517
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
16-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
17-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
18-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
1918
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
2019
import java.util.List;
2120
import java.util.Map;
@@ -117,9 +116,8 @@ public OpenAiChatCompletionResponse chatCompletionTools(final int months) {
117116
* @return the embedding response
118117
*/
119118
@Nonnull
120-
public EmbeddingsCreate200Response embedding(@Nonnull final String input) {
121-
final var request =
122-
new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input));
119+
public OpenAiEmbeddingResponse embedding(@Nonnull final String input) {
120+
final var request = new OpenAiEmbeddingRequest(List.of(input));
123121

124122
return OpenAiClient.forModel(TEXT_EMBEDDING_ADA_002).embedding(request);
125123
}

0 commit comments

Comments
 (0)