Skip to content

Commit 20d72c3

Browse files
committed
Request convenience
1 parent 826d351 commit 20d72c3

File tree

4 files changed

+223
-1
lines changed

4 files changed

+223
-1
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
226226
* @since 1.9.0
227227
*/
228228
@Nonnull
229-
EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
229+
public EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request)
230230
throws OrchestrationClientException {
231231
return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class);
232232
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.core.AiModel;
5+
import com.sap.ai.sdk.orchestration.model.EmbeddingsModelDetails;
6+
import com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams;
7+
import javax.annotation.Nonnull;
8+
import javax.annotation.Nullable;
9+
import lombok.AccessLevel;
10+
import lombok.AllArgsConstructor;
11+
import lombok.Value;
12+
import lombok.With;
13+
import lombok.experimental.Accessors;
14+
15+
// ideally this is a record but exposes all args constructor which we want to avoid. Is it worth a
16+
// value class?
17+
// Currently, model list follow SAP Notes as the source of truth. But deprecated models are not
18+
// listed there.
19+
// Can we reuse existing enum from generated class?
20+
@Beta
21+
@With
22+
@Value
23+
@Accessors(fluent = true)
24+
@AllArgsConstructor(access = AccessLevel.PRIVATE)
25+
public class OrchestrationEmbeddingModel implements AiModel {
26+
@Nonnull String name;
27+
@Nullable String version;
28+
@Nullable Integer dimensions;
29+
@Nullable Boolean normalize;
30+
@Nullable EmbeddingsModelParams.EncodingFormatEnum encodingFormat;
31+
32+
public OrchestrationEmbeddingModel(@Nonnull final String name) {
33+
this(name, null, null, null, null);
34+
}
35+
36+
public static final OrchestrationEmbeddingModel TEXT_EMBEDDING_3_SMALL =
37+
new OrchestrationEmbeddingModel("text-embedding-3-small");
38+
39+
public static final OrchestrationEmbeddingModel TEXT_EMBEDDING_3_LARGE =
40+
new OrchestrationEmbeddingModel("text-embedding-3-large");
41+
42+
public static final OrchestrationEmbeddingModel AMAZON_TITAN_EMBED_TEXT =
43+
new OrchestrationEmbeddingModel("amazon.titan-embed-text");
44+
45+
public static final OrchestrationEmbeddingModel NVIDIA_LLAMA_32_NV_EMBEDQA_1B =
46+
new OrchestrationEmbeddingModel("nvidia--llama-3.2-nv-embedqa-1b");
47+
48+
EmbeddingsModelDetails createEmbeddingsModelDetails() {
49+
50+
final var params =
51+
EmbeddingsModelParams.create()
52+
.dimensions(dimensions)
53+
.normalize(normalize)
54+
.encodingFormat(encodingFormat);
55+
return EmbeddingsModelDetails.create().name(name).version(version).params(params);
56+
}
57+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.DOCUMENT;
4+
import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.QUERY;
5+
import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.TEXT;
6+
import static lombok.AccessLevel.PRIVATE;
7+
8+
import com.google.common.annotations.Beta;
9+
import com.google.common.collect.Lists;
10+
import com.sap.ai.sdk.orchestration.model.EmbeddingsInput;
11+
import com.sap.ai.sdk.orchestration.model.EmbeddingsInputText;
12+
import com.sap.ai.sdk.orchestration.model.EmbeddingsModelConfig;
13+
import com.sap.ai.sdk.orchestration.model.EmbeddingsModuleConfigs;
14+
import com.sap.ai.sdk.orchestration.model.EmbeddingsOrchestrationConfig;
15+
import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest;
16+
import com.sap.ai.sdk.orchestration.model.MaskingModuleConfigProviders;
17+
import java.util.List;
18+
import javax.annotation.Nonnull;
19+
import javax.annotation.Nullable;
20+
import lombok.AllArgsConstructor;
21+
import lombok.Value;
22+
import lombok.With;
23+
import lombok.experimental.Tolerate;
24+
25+
// Do we need staged input builder here?
26+
// Do we need an enum for tokenType?
27+
@Beta
28+
@Value
29+
@AllArgsConstructor(access = PRIVATE)
30+
public class OrchestrationEmbeddingRequest {
31+
32+
@Nonnull OrchestrationEmbeddingModel model;
33+
@Nonnull List<String> tokens;
34+
35+
@With(value = PRIVATE)
36+
@Nullable
37+
List<MaskingProvider> masking;
38+
39+
@With(value = PRIVATE)
40+
@Nullable
41+
EmbeddingsInput.TypeEnum tokenType;
42+
43+
public static OrchestrationEmbeddingRequest create(
44+
OrchestrationEmbeddingModel model, List<String> tokens) {
45+
return new OrchestrationEmbeddingRequest(model, tokens, null, null);
46+
}
47+
48+
@Tolerate
49+
@Nonnull
50+
public OrchestrationEmbeddingRequest withMasking(
51+
@Nonnull final MaskingProvider maskingProvider,
52+
@Nonnull final MaskingProvider... maskingProviders) {
53+
return withMasking(Lists.asList(maskingProvider, maskingProviders));
54+
}
55+
56+
@Nonnull
57+
public OrchestrationEmbeddingRequest asDocument() {
58+
return withTokenType(DOCUMENT);
59+
}
60+
61+
@Nonnull
62+
public OrchestrationEmbeddingRequest asText() {
63+
return withTokenType(TEXT);
64+
}
65+
66+
@Nonnull
67+
public OrchestrationEmbeddingRequest asQuery() {
68+
return withTokenType(QUERY);
69+
}
70+
71+
EmbeddingsPostRequest createEmbeddingsPostRequest() {
72+
73+
final var input =
74+
EmbeddingsInput.create().text(EmbeddingsInputText.create(tokens)).type(tokenType);
75+
final var embeddingsModelConfig =
76+
EmbeddingsModelConfig.create().model(this.model.createEmbeddingsModelDetails());
77+
final var modules =
78+
EmbeddingsOrchestrationConfig.create()
79+
.modules(EmbeddingsModuleConfigs.create().embeddings(embeddingsModelConfig));
80+
81+
if (masking != null) {
82+
final var dpiConfigs = this.masking.stream().map(MaskingProvider::createConfig).toList();
83+
modules.getModules().setMasking(MaskingModuleConfigProviders.create().providers(dpiConfigs));
84+
}
85+
86+
return EmbeddingsPostRequest.create().config(modules).input(input);
87+
}
88+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.OrchestrationEmbeddingModel.TEXT_EMBEDDING_3_SMALL;
4+
import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.DOCUMENT;
5+
import static com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams.EncodingFormatEnum.BASE64;
6+
import static org.assertj.core.api.Assertions.assertThat;
7+
8+
import com.sap.ai.sdk.orchestration.model.DPIConfig;
9+
import com.sap.ai.sdk.orchestration.model.DPIEntities;
10+
import com.sap.ai.sdk.orchestration.model.DPIStandardEntity;
11+
import com.sap.ai.sdk.orchestration.model.EmbeddingsInputText;
12+
import com.sap.ai.sdk.orchestration.model.MaskingModuleConfigProviders;
13+
import java.util.List;
14+
import org.junit.jupiter.api.Test;
15+
16+
class OrchestrationEmbeddingTest {
17+
18+
@Test
19+
void embeddingModelTest() {
20+
final var model = TEXT_EMBEDDING_3_SMALL;
21+
assertThat(model.name().equals("text-embedding-3-small"));
22+
assertThat(model.version()).isNull();
23+
assertThat(model.dimensions()).isNull();
24+
assertThat(model.encodingFormat()).isNull();
25+
assertThat(model.normalize()).isNull();
26+
27+
final var model2 =
28+
TEXT_EMBEDDING_3_SMALL
29+
.withVersion("some-version")
30+
.withDimensions(1536)
31+
.withNormalize(true)
32+
.withEncodingFormat(BASE64);
33+
assertThat(model2.name().equals("text-embedding-3-large"));
34+
assertThat(model2.version().equals("some-version"));
35+
assertThat(model2.dimensions().equals(1536));
36+
assertThat(model2.normalize().equals(true));
37+
assertThat(model2.encodingFormat().equals(BASE64));
38+
39+
final var custom = new OrchestrationEmbeddingModel("custom-model");
40+
assertThat(custom.name()).isEqualTo("custom-model");
41+
}
42+
43+
@Test
44+
void embeddingRequestTest() {
45+
final var request =
46+
OrchestrationEmbeddingRequest.create(TEXT_EMBEDDING_3_SMALL, List.of("token1", "token2"))
47+
.asDocument()
48+
.withMasking(
49+
DpiMasking.anonymization()
50+
.withEntities(DPIEntities.ADDRESS)
51+
.withAllowList(List.of("Alice")));
52+
53+
final var postRequest = request.createEmbeddingsPostRequest();
54+
assertThat(postRequest.getInput().getText())
55+
.isEqualTo(EmbeddingsInputText.create(List.of("token1", "token2")));
56+
assertThat(postRequest.getInput().getType()).isEqualTo(DOCUMENT);
57+
final var embeddingsModelConfig = postRequest.getConfig().getModules().getEmbeddings();
58+
assertThat(embeddingsModelConfig.getModel().getName()).isEqualTo("text-embedding-3-small");
59+
assertThat(embeddingsModelConfig.getModel().getVersion()).isNull();
60+
assertThat(embeddingsModelConfig.getModel().getParams().getDimensions()).isNull();
61+
assertThat(embeddingsModelConfig.getModel().getParams().getEncodingFormat()).isNull();
62+
assertThat(embeddingsModelConfig.getModel().getParams().isNormalize()).isNull();
63+
64+
final var maskingConfig = postRequest.getConfig().getModules().getMasking();
65+
assertThat(maskingConfig)
66+
.isInstanceOfSatisfying(
67+
MaskingModuleConfigProviders.class,
68+
cfg -> {
69+
assertThat(cfg.getProviders()).hasSize(1);
70+
assertThat(cfg.getProviders().get(0).getType())
71+
.isEqualTo(DPIConfig.TypeEnum.SAP_DATA_PRIVACY_INTEGRATION);
72+
assertThat(cfg.getProviders().get(0).getEntities())
73+
.isEqualTo(List.of(DPIStandardEntity.create().type(DPIEntities.ADDRESS)));
74+
assertThat(cfg.getProviders().get(0).getAllowlist()).isEqualTo(List.of("Alice"));
75+
});
76+
}
77+
}

0 commit comments

Comments
 (0)