Skip to content

Commit 2adf1e4

Browse files
committed
Embedding convenience similar to chat completion request
1 parent 5518ce0 commit 2adf1e4

File tree

5 files changed

+95
-12
lines changed

5 files changed

+95
-12
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
2323
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreate200Response;
2424
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequest;
25-
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequestInput;
2625
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
2726
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
2827
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
@@ -349,19 +348,18 @@ private void warnIfUnsupportedUsage() {
349348
}
350349

351350
/**
352-
* Get a vector representation of a given string input that can be easily consumed by machine
353-
* learning models and algorithms.
351+
* Get a vector representation of a given convenience request that can be easily consumed by
352+
* machine learning models and algorithms.
354353
*
355-
* @param input the input text.
354+
* @param request the embedding request.
356355
* @return the embedding output
357356
* @throws OpenAiClientException if the request fails
358357
* @since 1.3.0
359358
*/
360359
@Nonnull
361-
public EmbeddingsCreate200Response embedding(@Nonnull final String input)
360+
public EmbeddingsCreate200Response embedding(@Nonnull final OpenAiEmbeddingRequest request)
362361
throws OpenAiClientException {
363-
return embedding(
364-
new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input)));
362+
return embedding(request.toEmbeddingsCreateRequest());
365363
}
366364

367365
/**
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequest;
5+
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequestInput;
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.List;
9+
import javax.annotation.Nonnull;
10+
import lombok.experimental.Accessors;
11+
12+
/**
13+
* Represents a request to create embeddings using OpenAI's API.
14+
*
15+
* @since 1.3.0
16+
*/
17+
@Beta
18+
@Accessors(fluent = true)
19+
public class OpenAiEmbeddingRequest {
20+
/** List of tokens to be embedded. */
21+
@Nonnull private final List<String> tokens = new ArrayList<>();
22+
23+
/**
24+
* Constructs an OpenAiEmbeddingRequest with a single token.
25+
*
26+
* @param token the token to be embedded
27+
*/
28+
public OpenAiEmbeddingRequest(@Nonnull final String token) {
29+
tokens.add(token);
30+
}
31+
32+
/**
33+
* Constructs an OpenAiEmbeddingRequest with multiple tokens.
34+
*
35+
* @param token the first token to be embedded
36+
* @param tokens additional tokens to be embedded
37+
*/
38+
public OpenAiEmbeddingRequest(@Nonnull final String token, @Nonnull final String... tokens) {
39+
this.tokens.add(token);
40+
this.tokens.addAll(Arrays.asList(tokens));
41+
}
42+
43+
/**
44+
* Converts this request to an EmbeddingsCreateRequest.
45+
*
46+
* @return an EmbeddingsCreateRequest with the tokens to be embedded
47+
*/
48+
@Nonnull
49+
public EmbeddingsCreateRequest toEmbeddingsCreateRequest() {
50+
if (tokens.size() == 1) {
51+
return new EmbeddingsCreateRequest()
52+
.input(EmbeddingsCreateRequestInput.create(tokens.get(0)));
53+
}
54+
55+
return new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(tokens));
56+
}
57+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
6+
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequestInput;
7+
import org.junit.jupiter.api.Test;
8+
9+
@WireMockTest
10+
class EmbeddingConvenienceTest {
11+
12+
@Test
13+
void createEmbeddingRequestWithSingleToken() {
14+
var request = new OpenAiEmbeddingRequest("token1");
15+
var lowLevelRequest = request.toEmbeddingsCreateRequest();
16+
17+
assertThat(((EmbeddingsCreateRequestInput.InnerString) lowLevelRequest.getInput()))
18+
.usingRecursiveComparison()
19+
.isEqualTo(EmbeddingsCreateRequestInput.create("token1"));
20+
}
21+
22+
@Test
23+
void createEmbeddingRequestWithMultipleTokens() {
24+
var request = new OpenAiEmbeddingRequest("token1", "token2", "token3");
25+
var lowLevelRequest = request.toEmbeddingsCreateRequest();
26+
27+
assertThat(((EmbeddingsCreateRequestInput.InnerStrings) lowLevelRequest.getInput()).values())
28+
.containsExactly("token1", "token2", "token3");
29+
}
30+
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/NewOpenAiClientTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ void history() {
241241
void embedding() {
242242
stubForEmbedding();
243243

244-
final var result = client.embedding("Hello World");
244+
final var result = client.embedding(new OpenAiEmbeddingRequest("Hello World"));
245245

246246
assertThat(result).isNotNull();
247247
assertThat(result.getModel()).isEqualTo("ada");

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/NewOpenAiService.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
1111
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse;
1212
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
13+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest;
1314
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1415
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1516
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionNamedToolChoice;
1617
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionNamedToolChoiceFunction;
1718
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionTool;
1819
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionToolChoiceOption;
1920
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreate200Response;
20-
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequest;
21-
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequestInput;
2221
import com.sap.ai.sdk.foundationmodels.openai.model2.FunctionObject;
2322
import java.util.List;
2423
import java.util.Map;
@@ -126,8 +125,7 @@ public OpenAiChatCompletionResponse chatCompletionTools(@Nonnull final String de
126125
*/
127126
@Nonnull
128127
public EmbeddingsCreate200Response embedding(@Nonnull final String input) {
129-
final var request =
130-
new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input));
128+
final var request = new OpenAiEmbeddingRequest(input);
131129

132130
return OpenAiClient.forModel(TEXT_EMBEDDING_ADA_002).embedding(request);
133131
}

0 commit comments

Comments
 (0)