Skip to content

Commit a5835ec

Browse files
authored
Support embeddings (#16)
2 parents 54e4942 + cc0b895 commit a5835ec

File tree

9 files changed

+220
-5
lines changed

9 files changed

+220
-5
lines changed

src/main/java/org/devlive/sdk/openai/DefaultApi.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import okhttp3.RequestBody;
66
import org.devlive.sdk.openai.entity.CompletionChatEntity;
77
import org.devlive.sdk.openai.entity.CompletionEntity;
8+
import org.devlive.sdk.openai.entity.EmbeddingEntity;
89
import org.devlive.sdk.openai.entity.ImageEntity;
910
import org.devlive.sdk.openai.entity.ModelEntity;
1011
import org.devlive.sdk.openai.entity.UserKeyEntity;
1112
import org.devlive.sdk.openai.response.CompleteChatResponse;
1213
import org.devlive.sdk.openai.response.CompleteResponse;
14+
import org.devlive.sdk.openai.response.EmbeddingResponse;
1315
import org.devlive.sdk.openai.response.ImageResponse;
1416
import org.devlive.sdk.openai.response.ModelResponse;
1517
import org.devlive.sdk.openai.response.UserKeyResponse;
@@ -91,4 +93,11 @@ Single<ImageResponse> fetchImagesEdits(@Url String url,
9193
Single<ImageResponse> fetchImagesVariations(@Url String url,
9294
@Part() MultipartBody.Part image,
9395
@PartMap Map<String, RequestBody> configure);
96+
97+
/**
98+
* Creates an embedding vector representing the input text.
99+
*/
100+
@POST
101+
Single<EmbeddingResponse> fetchEmbeddings(@Url String url,
102+
@Body EmbeddingEntity configure);
94103
}

src/main/java/org/devlive/sdk/openai/DefaultClient.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import org.apache.commons.lang3.ObjectUtils;
77
import org.devlive.sdk.openai.entity.CompletionChatEntity;
88
import org.devlive.sdk.openai.entity.CompletionEntity;
9+
import org.devlive.sdk.openai.entity.EmbeddingEntity;
910
import org.devlive.sdk.openai.entity.ImageEntity;
1011
import org.devlive.sdk.openai.entity.ModelEntity;
1112
import org.devlive.sdk.openai.entity.UserKeyEntity;
1213
import org.devlive.sdk.openai.model.ProviderModel;
1314
import org.devlive.sdk.openai.model.UrlModel;
1415
import org.devlive.sdk.openai.response.CompleteChatResponse;
1516
import org.devlive.sdk.openai.response.CompleteResponse;
17+
import org.devlive.sdk.openai.response.EmbeddingResponse;
1618
import org.devlive.sdk.openai.response.ImageResponse;
1719
import org.devlive.sdk.openai.response.ModelResponse;
1820
import org.devlive.sdk.openai.response.UserKeyResponse;
@@ -93,6 +95,13 @@ public ImageResponse variationsImages(ImageEntity configure)
9395
.blockingGet();
9496
}
9597

98+
public EmbeddingResponse createEmbeddings(EmbeddingEntity configure)
99+
{
100+
return this.api.fetchEmbeddings(ProviderUtils.getUrl(provider, UrlModel.FETCH_EMBEDDINGS),
101+
configure)
102+
.blockingGet();
103+
}
104+
96105
public void close()
97106
{
98107
if (ObjectUtils.isNotEmpty(this.client)) {
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package org.devlive.sdk.openai.entity;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Builder;
7+
import lombok.Data;
8+
import lombok.NoArgsConstructor;
9+
import lombok.ToString;
10+
import org.apache.commons.lang3.StringUtils;
11+
import org.devlive.sdk.openai.exception.ParamException;
12+
13+
import java.util.List;
14+
15+
@Data
16+
@Builder
17+
@ToString
18+
@NoArgsConstructor
19+
@AllArgsConstructor
20+
@JsonIgnoreProperties(ignoreUnknown = true)
21+
public class EmbeddingEntity
22+
{
23+
/**
24+
* ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
25+
*/
26+
@JsonProperty(value = "model")
27+
private String model;
28+
29+
/**
30+
* Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. Each input must not exceed the max input tokens for the model (8191 tokens for text-embedding-ada-002).
31+
*/
32+
@JsonProperty(value = "input")
33+
private String input;
34+
35+
/**
36+
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
37+
*/
38+
@JsonProperty(value = "user")
39+
private String user;
40+
41+
/* ====================== Response ====================== */
42+
@JsonProperty(value = "object")
43+
private String object;
44+
45+
@JsonProperty(value = "embedding")
46+
private List<Object> embeddings;
47+
48+
@JsonProperty(value = "index")
49+
private long index;
50+
51+
private EmbeddingEntity(EmbeddingEntityBuilder builder)
52+
{
53+
if (StringUtils.isEmpty(builder.model)) {
54+
builder.model(null);
55+
}
56+
this.model = builder.model;
57+
58+
if (StringUtils.isEmpty(builder.input)) {
59+
builder.input(null);
60+
}
61+
this.input = builder.input;
62+
63+
this.user = builder.user;
64+
}
65+
66+
public static class EmbeddingEntityBuilder
67+
{
68+
public EmbeddingEntityBuilder model(String model)
69+
{
70+
if (!model.equals("text-embedding-ada-002")
71+
&& !(model.startsWith("text-similarity-") && model.endsWith("-001"))
72+
&& !(model.startsWith("text-search-") && model.endsWith("-001"))
73+
&& !(model.startsWith("code-search-") && model.endsWith("-001"))) {
74+
throw new ParamException(String.format("Invalid model %s must be specified, Support text-embedding-ada-002, text-similarity-*-001, text-search-*-*-001, code-search-*-*-001", model));
75+
}
76+
this.model = model;
77+
return this;
78+
}
79+
80+
public EmbeddingEntityBuilder input(String input)
81+
{
82+
if (StringUtils.isEmpty(input)) {
83+
throw new ParamException("Invalid input must be not empty");
84+
}
85+
this.input = input;
86+
return this;
87+
}
88+
89+
public EmbeddingEntity build()
90+
{
91+
return new EmbeddingEntity(this);
92+
}
93+
}
94+
}

src/main/java/org/devlive/sdk/openai/model/CompletionModel.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,32 @@ public enum CompletionModel
6060
CODE_DAVINCI_002("code-davinci-002",
6161
"Optimized for code-completion tasks",
6262
null,
63-
8001);
63+
8001),
64+
65+
TEXT_MODERATION_LATEST("text-moderation-latest",
66+
"Most capable moderation model. Accuracy will be slighlty higher than the stable model.",
67+
null,
68+
Integer.MAX_VALUE),
69+
TEXT_MODERATION_STABLE("text-moderation-stable",
70+
"Almost as capable as the latest model, but slightly older.\n",
71+
null,
72+
Integer.MAX_VALUE),
73+
DAVINCI("davinci",
74+
"Most capable GPT-3 model. Can do any task the other models can do, often with higher quality.",
75+
null,
76+
2049),
77+
CURIE("curie",
78+
"Very capable, but faster and lower cost than Davinci.",
79+
null,
80+
2049),
81+
BABBAGE("babbage",
82+
"Capable of straightforward tasks, very fast, and lower cost.",
83+
null,
84+
2049),
85+
ADA("ada",
86+
"Capable of very simple tasks, usually the fastest model in the GPT-3 series, and lowest cost.",
87+
null,
88+
2049);
6489

6590
private final String name;
6691
private final String description;

src/main/java/org/devlive/sdk/openai/model/UrlModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ public enum UrlModel
1010
FETCH_CREATE_USER_API_KEY,
1111
FETCH_IMAGES_GENERATIONS,
1212
FETCH_IMAGES_EDITS,
13-
FETCH_IMAGES_VARIATIONS
13+
FETCH_IMAGES_VARIATIONS,
14+
FETCH_EMBEDDINGS
1415
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package org.devlive.sdk.openai.response;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Builder;
7+
import lombok.Data;
8+
import lombok.NoArgsConstructor;
9+
import lombok.ToString;
10+
import org.devlive.sdk.openai.entity.EmbeddingEntity;
11+
import org.devlive.sdk.openai.entity.UsageEntity;
12+
13+
import java.util.List;
14+
15+
@Data
16+
@Builder
17+
@ToString
18+
@NoArgsConstructor
19+
@AllArgsConstructor
20+
@JsonIgnoreProperties(ignoreUnknown = true)
21+
public class EmbeddingResponse
22+
{
23+
@JsonProperty(value = "object")
24+
private String object;
25+
26+
@JsonProperty(value = "data")
27+
private List<EmbeddingEntity> embeddings;
28+
29+
@JsonProperty(value = "model")
30+
private String model;
31+
32+
@JsonProperty(value = "usage")
33+
private UsageEntity usage;
34+
}

src/main/java/org/devlive/sdk/openai/utils/ProviderUtils.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ public class ProviderUtils
1717
DEFAULT_PROVIDER.put(UrlModel.FETCH_MODELS, "v1/models");
1818
DEFAULT_PROVIDER.put(UrlModel.FETCH_MODEL, "v1/models/{model}");
1919
DEFAULT_PROVIDER.put(UrlModel.FETCH_COMPLETIONS, "v1/completions");
20-
AZURE_PROVIDER.put(UrlModel.FETCH_COMPLETIONS, "completions");
2120
DEFAULT_PROVIDER.put(UrlModel.FETCH_CHAT_COMPLETIONS, "v1/chat/completions");
22-
AZURE_PROVIDER.put(UrlModel.FETCH_CHAT_COMPLETIONS, "chat/completions");
2321
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_GENERATIONS, "v1/images/generations");
2422
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_EDITS, "v1/images/edits");
2523
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_VARIATIONS, "v1/images/variations");
24+
DEFAULT_PROVIDER.put(UrlModel.FETCH_EMBEDDINGS, "v1/embeddings");
25+
26+
AZURE_PROVIDER.put(UrlModel.FETCH_COMPLETIONS, "completions");
27+
AZURE_PROVIDER.put(UrlModel.FETCH_CHAT_COMPLETIONS, "chat/completions");
2628
}
2729

2830
private ProviderUtils()

src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.devlive.sdk.openai.entity.CompletionChatEntity;
66
import org.devlive.sdk.openai.entity.CompletionEntity;
77
import org.devlive.sdk.openai.entity.CompletionMessageEntity;
8+
import org.devlive.sdk.openai.entity.EmbeddingEntity;
89
import org.devlive.sdk.openai.entity.ImageEntity;
910
import org.devlive.sdk.openai.entity.UserKeyEntity;
1011
import org.devlive.sdk.openai.exception.AuthorizedException;
@@ -30,7 +31,6 @@ public void before()
3031
{
3132
client = OpenAiClient.builder()
3233
.apiKey(System.getProperty("openai.token"))
33-
.apiKey("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")
3434
.build();
3535
}
3636

@@ -165,4 +165,14 @@ public void testVariationsImages()
165165
.build();
166166
Assert.assertTrue(client.variationsImages(configure).getImages().size() > 0);
167167
}
168+
169+
@Test
170+
public void testCreateEmbeddings()
171+
{
172+
EmbeddingEntity configure = EmbeddingEntity.builder()
173+
.model("text-similarity-ada-001")
174+
.input("Hello OpenAi Java SDK")
175+
.build();
176+
Assert.assertTrue(client.createEmbeddings(configure).getEmbeddings().size() > 0);
177+
}
168178
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.devlive.sdk.openai.entity;
2+
3+
import org.devlive.sdk.openai.exception.ParamException;
4+
import org.devlive.sdk.openai.model.CompletionModel;
5+
import org.junit.Assert;
6+
import org.junit.Test;
7+
8+
public class EmbeddingEntityTest
9+
{
10+
@Test
11+
public void testModel()
12+
{
13+
Assert.assertThrows(ParamException.class, () -> EmbeddingEntity.builder()
14+
.model("testModel")
15+
.build());
16+
17+
Assert.assertEquals(EmbeddingEntity.builder()
18+
.model("text-similarity-ada-001")
19+
.input("Test")
20+
.build()
21+
.getModel(), "text-similarity-ada-001");
22+
}
23+
24+
@Test
25+
public void testInput()
26+
{
27+
Assert.assertThrows(ParamException.class, () -> EmbeddingEntity.builder()
28+
.model(CompletionModel.BABBAGE.getName())
29+
.build());
30+
}
31+
}

0 commit comments

Comments
 (0)