Skip to content

Commit c82c347

Browse files
committed
Support variations images
1 parent 297ad5e commit c82c347

File tree

8 files changed

+157
-35
lines changed

8 files changed

+157
-35
lines changed

.github/workflows/checker.yml

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,7 @@ jobs:
7171
distribution: 'temurin'
7272
- run: chmod 755 ./mvnw
7373
- run: |
74-
./mvnw clean install test \
75-
-Dfindbugs.skip \
76-
-Dcheckstyle.skip \
77-
-Dgpg.skip -Dskip.yarn \
78-
-Dopenai.token=${{ secrets.OPENAI_TOKEN }} \
79-
-Dproxy.token=${{ secrets.PROXY_TOKEN }} \
80-
-Dproxy.host=${{ secrets.PROXY_HOST }} \
81-
-Dazure.token=${{ secrets.AZURE_TOKEN}}
74+
./mvnw clean install test -Dfindbugs.skip -Dcheckstyle.skip -Dgpg.skip -Dskip.yarn -Dopenai.token=${{ secrets.OPENAI_TOKEN }} -Dproxy.token=${{ secrets.PROXY_TOKEN }} -Dproxy.host=${{ secrets.PROXY_HOST }} -Dazure.token=${{ secrets.AZURE_TOKEN}}
8275
8376
before_checker_package:
8477
runs-on: ubuntu-latest

src/main/java/org/devlive/sdk/openai/DefaultApi.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ Single<CompleteChatResponse> fetchChatCompletions(@Url String url,
7070
* Creates an image given a prompt.
7171
*/
7272
@POST
73-
Single<ImageResponse> fetchImagesGenerations(@Url String url, @Body ImageEntity configure);
73+
Single<ImageResponse> fetchImagesGenerations(@Url String url,
74+
@Body ImageEntity configure);
7475

7576
/**
7677
* Creates an edited or extended image given an original image and a prompt.
@@ -81,4 +82,13 @@ Single<ImageResponse> fetchImagesEdits(@Url String url,
8182
@Part() MultipartBody.Part image,
8283
@Part() MultipartBody.Part mask,
8384
@PartMap Map<String, RequestBody> configure);
85+
86+
/**
87+
* Creates a variation of a given image.
88+
*/
89+
@POST
90+
@Multipart
91+
Single<ImageResponse> fetchImagesVariations(@Url String url,
92+
@Part() MultipartBody.Part image,
93+
@PartMap Map<String, RequestBody> configure);
8494
}

src/main/java/org/devlive/sdk/openai/DefaultClient.java

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
package org.devlive.sdk.openai;
22

3-
import com.google.common.collect.Maps;
43
import lombok.extern.slf4j.Slf4j;
54
import okhttp3.MultipartBody;
65
import okhttp3.OkHttpClient;
7-
import okhttp3.RequestBody;
86
import org.apache.commons.lang3.ObjectUtils;
9-
import org.apache.commons.lang3.StringUtils;
107
import org.devlive.sdk.openai.entity.CompletionChatEntity;
118
import org.devlive.sdk.openai.entity.CompletionEntity;
129
import org.devlive.sdk.openai.entity.ImageEntity;
@@ -22,9 +19,6 @@
2219
import org.devlive.sdk.openai.utils.MultipartBodyUtils;
2320
import org.devlive.sdk.openai.utils.ProviderUtils;
2421

25-
import java.io.File;
26-
import java.util.Map;
27-
2822
@Slf4j
2923
public abstract class DefaultClient implements AutoCloseable
3024
{
@@ -70,31 +64,32 @@ public UserKeyResponse createUserAPIKey(UserKeyEntity configure)
7064

7165
public ImageResponse createImages(ImageEntity configure)
7266
{
67+
configure.setIsVariation(null);
68+
configure.setIsEdit(null);
7369
return this.api.fetchImagesGenerations(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_GENERATIONS), configure)
7470
.blockingGet();
7571
}
7672

77-
public ImageResponse editImages(File image, File mask, ImageEntity configure)
73+
public ImageResponse editImages(ImageEntity configure)
7874
{
79-
MultipartBody.Part imageBody = MultipartBodyUtils.getPart(image, "image");
75+
MultipartBody.Part imageBody = MultipartBodyUtils.getPart(configure.getImage(), "image");
8076
MultipartBody.Part maskBody = null;
81-
if (ObjectUtils.isNotEmpty(mask)) {
82-
maskBody = MultipartBodyUtils.getPart(mask, "mask");
83-
}
84-
85-
Map<String, RequestBody> map = Maps.newConcurrentMap();
86-
map.put("prompt", RequestBody.create(MultipartBodyUtils.TYPE, configure.getPrompt()));
87-
map.put("n", RequestBody.create(MultipartBodyUtils.TYPE, configure.getCount().toString()));
88-
map.put("size", RequestBody.create(MultipartBodyUtils.TYPE, configure.getSize()));
89-
map.put("response_format", RequestBody.create(MultipartBodyUtils.TYPE, configure.getFormat()));
90-
if (StringUtils.isNotEmpty(configure.getUser())) {
91-
map.put("user", RequestBody.create(MultipartBodyUtils.TYPE, configure.getUser()));
77+
if (ObjectUtils.isNotEmpty(configure.getMask())) {
78+
maskBody = MultipartBodyUtils.getPart(configure.getMask(), "mask");
9279
}
93-
94-
return this.api.fetchImagesEdits(ProviderUtils.getUrl(provider, UrlModel.FETCH_EDITS_GENERATIONS),
80+
return this.api.fetchImagesEdits(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_EDITS),
9581
imageBody,
9682
maskBody,
97-
map)
83+
configure.convertMap())
84+
.blockingGet();
85+
}
86+
87+
public ImageResponse variationsImages(ImageEntity configure)
88+
{
89+
MultipartBody.Part imageBody = MultipartBodyUtils.getPart(configure.getImage(), "image");
90+
return this.api.fetchImagesVariations(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_VARIATIONS),
91+
imageBody,
92+
configure.convertMap())
9893
.blockingGet();
9994
}
10095

src/main/java/org/devlive/sdk/openai/entity/ImageEntity.java

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,24 @@
22

33
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
44
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import com.google.common.collect.Maps;
56
import lombok.AllArgsConstructor;
67
import lombok.Builder;
78
import lombok.Data;
89
import lombok.NoArgsConstructor;
910
import lombok.ToString;
11+
import okhttp3.RequestBody;
1012
import org.apache.commons.lang3.EnumUtils;
1113
import org.apache.commons.lang3.ObjectUtils;
1214
import org.apache.commons.lang3.StringUtils;
1315
import org.devlive.sdk.openai.exception.ParamException;
1416
import org.devlive.sdk.openai.model.ImageFormatModel;
1517
import org.devlive.sdk.openai.model.ImageSizeModel;
18+
import org.devlive.sdk.openai.utils.MultipartBodyUtils;
1619

20+
import java.io.File;
1721
import java.util.Arrays;
22+
import java.util.Map;
1823

1924
@Data
2025
@Builder
@@ -61,6 +66,31 @@ public class ImageEntity
6166
@JsonProperty(value = "url")
6267
private String url;
6368

69+
@JsonProperty(value = "image")
70+
private File image;
71+
72+
@JsonProperty(value = "mask")
73+
private File mask;
74+
75+
private Boolean isEdit;
76+
private Boolean isVariation;
77+
78+
public Map<String, RequestBody> convertMap()
79+
{
80+
Map<String, RequestBody> map = Maps.newConcurrentMap();
81+
if (this.isEdit) {
82+
map.put("prompt", RequestBody.create(MultipartBodyUtils.TYPE, this.getPrompt()));
83+
}
84+
map.put("n", RequestBody.create(MultipartBodyUtils.TYPE, this.getCount().toString()));
85+
map.put("size", RequestBody.create(MultipartBodyUtils.TYPE, this.getSize()));
86+
map.put("response_format", RequestBody.create(MultipartBodyUtils.TYPE, this.getFormat()));
87+
88+
if (StringUtils.isNotEmpty(this.getUser())) {
89+
map.put("user", RequestBody.create(MultipartBodyUtils.TYPE, this.getUser()));
90+
}
91+
return map;
92+
}
93+
6494
private ImageEntity(ImageEntityBuilder builder)
6595
{
6696
if (ObjectUtils.isEmpty(builder.prompt)) {
@@ -83,14 +113,34 @@ private ImageEntity(ImageEntityBuilder builder)
83113
}
84114
this.format = builder.format;
85115

116+
if (ObjectUtils.isEmpty(builder.image)) {
117+
builder.image(null);
118+
}
119+
this.image = builder.image;
120+
121+
if (ObjectUtils.isEmpty(builder.mask)) {
122+
builder.mask(null);
123+
}
124+
this.mask = builder.mask;
125+
126+
if (ObjectUtils.isEmpty(builder.isEdit)) {
127+
builder.isEdit(Boolean.FALSE);
128+
}
129+
this.isEdit = builder.isEdit;
130+
131+
if (ObjectUtils.isEmpty(builder.isVariation)) {
132+
builder.isVariation(Boolean.FALSE);
133+
}
134+
this.isVariation = builder.isVariation;
135+
86136
this.user = builder.user;
87137
}
88138

89139
public static class ImageEntityBuilder
90140
{
91141
public ImageEntityBuilder prompt(String prompt)
92142
{
93-
if (StringUtils.isEmpty(prompt)) {
143+
if ((ObjectUtils.isEmpty(this.isVariation) || !this.isVariation) && StringUtils.isEmpty(prompt)) {
94144
throw new ParamException("Invalid prompt must not be empty");
95145
}
96146
this.prompt = prompt;
@@ -126,6 +176,46 @@ public ImageEntityBuilder format(ImageFormatModel format)
126176
return this;
127177
}
128178

179+
public ImageEntityBuilder image(File image)
180+
{
181+
if (ObjectUtils.isNotEmpty(image) && image.length() > 4 * 1024 * 102) {
182+
throw new ParamException("Must be less than 4MB");
183+
}
184+
this.image = image;
185+
return this;
186+
}
187+
188+
public ImageEntityBuilder mask(File mask)
189+
{
190+
if (ObjectUtils.isNotEmpty(mask) && mask.length() > 4 * 1024 * 102) {
191+
throw new ParamException("Must be less than 4MB");
192+
}
193+
this.mask = mask;
194+
return this;
195+
}
196+
197+
public ImageEntityBuilder isEdit(Boolean isEdit)
198+
{
199+
if (isEdit && ObjectUtils.isEmpty(this.image)) {
200+
throw new ParamException("Image must not be empty.");
201+
}
202+
this.isEdit = isEdit;
203+
return this;
204+
}
205+
206+
public ImageEntityBuilder isVariation(Boolean isVariation)
207+
{
208+
if (isVariation && ObjectUtils.isNotEmpty(this.prompt)) {
209+
throw new ParamException("Please remove prompt");
210+
}
211+
212+
if (isVariation && ObjectUtils.isEmpty(this.image)) {
213+
throw new ParamException("Image must not be empty.");
214+
}
215+
this.isVariation = isVariation;
216+
return this;
217+
}
218+
129219
public ImageEntity build()
130220
{
131221
return new ImageEntity(this);

src/main/java/org/devlive/sdk/openai/model/UrlModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ public enum UrlModel
99
FETCH_USER_API_KEYS,
1010
FETCH_CREATE_USER_API_KEY,
1111
FETCH_IMAGES_GENERATIONS,
12-
FETCH_EDITS_GENERATIONS
12+
FETCH_IMAGES_EDITS,
13+
FETCH_IMAGES_VARIATIONS
1314
}

src/main/java/org/devlive/sdk/openai/utils/ProviderUtils.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ public class ProviderUtils
2121
DEFAULT_PROVIDER.put(UrlModel.FETCH_CHAT_COMPLETIONS, "v1/chat/completions");
2222
AZURE_PROVIDER.put(UrlModel.FETCH_CHAT_COMPLETIONS, "chat/completions");
2323
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_GENERATIONS, "v1/images/generations");
24-
DEFAULT_PROVIDER.put(UrlModel.FETCH_EDITS_GENERATIONS, "v1/images/edits");
24+
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_EDITS, "v1/images/edits");
25+
DEFAULT_PROVIDER.put(UrlModel.FETCH_IMAGES_VARIATIONS, "v1/images/variations");
2526
}
2627

2728
private ProviderUtils()

src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public void before()
3030
{
3131
client = OpenAiClient.builder()
3232
.apiKey(System.getProperty("openai.token"))
33+
.apiKey("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")
3334
.build();
3435
}
3536

@@ -145,10 +146,23 @@ public void testCreateImages()
145146
@Test
146147
public void testEditImages()
147148
{
149+
String file = this.getClass().getResource("/logo.png").getFile();
148150
ImageEntity configure = ImageEntity.builder()
149151
.prompt("Add hello to image")
152+
.image(new File(file))
153+
.isEdit(Boolean.TRUE)
150154
.build();
155+
Assert.assertTrue(client.editImages(configure).getImages().size() > 0);
156+
}
157+
158+
@Test
159+
public void testVariationsImages()
160+
{
151161
String file = this.getClass().getResource("/logo.png").getFile();
152-
Assert.assertTrue(client.editImages(new File(file), null, configure).getImages().size() > 0);
162+
ImageEntity configure = ImageEntity.builder()
163+
.image(new File(file))
164+
.isVariation(Boolean.TRUE)
165+
.build();
166+
Assert.assertTrue(client.variationsImages(configure).getImages().size() > 0);
153167
}
154168
}

src/test/java/org/devlive/sdk/openai/entity/ImageEntityTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import org.junit.Assert;
66
import org.junit.Test;
77

8+
import java.io.File;
9+
810
public class ImageEntityTest
911
{
1012
private String prompt = "Enter";
@@ -35,4 +37,20 @@ public void testSize()
3537
.build();
3638
Assert.assertEquals(entity.getSize(), ImageSizeModel.X_256.getName());
3739
}
40+
41+
@Test
42+
public void testImage()
43+
{
44+
Assert.assertThrows(ParamException.class, () -> ImageEntity.builder()
45+
.isEdit(Boolean.TRUE)
46+
.build());
47+
48+
String file = this.getClass().getResource("/logo.png").getFile();
49+
Assert.assertTrue(ImageEntity.builder()
50+
.prompt("Testing")
51+
.image(new File(file))
52+
.isEdit(Boolean.TRUE)
53+
.build()
54+
.getIsEdit());
55+
}
3856
}

0 commit comments

Comments
 (0)