Skip to content

Commit fdc292b

Browse files
committed
Support azure proxy
1 parent 6018b29 commit fdc292b

File tree

14 files changed

+323
-43
lines changed

14 files changed

+323
-43
lines changed

.github/workflows/checker.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ jobs:
7070
17
7171
distribution: 'temurin'
7272
- run: chmod 755 ./mvnw
73-
- run: ./mvnw clean install test -Dfindbugs.skip -Dcheckstyle.skip -Dgpg.skip -Dskip.yarn -Dopenai.token=${{ secrets.OPENAI_TOKEN }} -Dproxy.token=${{ secrets.PROXY_TOKEN }} -Dproxy.host=${{ secrets.PROXY_HOST }}
73+
- run: ./mvnw clean install test -Dfindbugs.skip -Dcheckstyle.skip -Dgpg.skip -Dskip.yarn \
74+
-Dopenai.token=${{ secrets.OPENAI_TOKEN }} \
75+
-Dproxy.token=${{ secrets.PROXY_TOKEN }} \
76+
-Dproxy.host=${{ secrets.PROXY_HOST }} \
77+
-Dazure.token=${{ secrets.AZURE_TOKEN}}
7478

7579
before_checker_package:
7680
runs-on: ubuntu-latest

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<groupId>org.devlive.sdk</groupId>
77
<artifactId>openai-java-sdk</artifactId>
8-
<version>1.2.0</version>
8+
<version>1.3.0-SNAPSHOT</version>
99

1010
<name>openai-java-sdk</name>
1111
<description>

src/main/java/org/devlive/sdk/openai/DefaultApi.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
import retrofit2.http.GET;
1414
import retrofit2.http.POST;
1515
import retrofit2.http.Path;
16+
import retrofit2.http.Url;
1617

1718
public interface DefaultApi
1819
{
1920
/**
2021
* Lists the currently available models
2122
*/
22-
@GET(value = "v1/models")
23-
Single<ModelResponse> fetchModels();
23+
@GET
24+
Single<ModelResponse> fetchModels(@Url String url);
2425

2526
/**
2627
* Retrieves a model instance, providing basic information about the model such as the owner and permissioning.
@@ -33,14 +34,16 @@ public interface DefaultApi
3334
/**
3435
* Creates a completion for the provided prompt and parameters.
3536
*/
36-
@POST(value = "v1/completions")
37-
Single<CompleteResponse> fetchCompletions(@Body CompletionEntity configure);
37+
@POST
38+
Single<CompleteResponse> fetchCompletions(@Url String url,
39+
@Body CompletionEntity configure);
3840

3941
/**
4042
* Creates a model response for the given chat conversation.
4143
*/
42-
@POST(value = "v1/chat/completions")
43-
Single<CompleteChatResponse> fetchChatCompletions(@Body CompletionChatEntity configure);
44+
@POST
45+
Single<CompleteChatResponse> fetchChatCompletions(@Url String url,
46+
@Body CompletionChatEntity configure);
4447

4548
/**
4649
* Get all keys

src/main/java/org/devlive/sdk/openai/DefaultClient.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
package org.devlive.sdk.openai;
22

3+
import lombok.extern.slf4j.Slf4j;
34
import org.devlive.sdk.openai.entity.CompletionChatEntity;
45
import org.devlive.sdk.openai.entity.CompletionEntity;
56
import org.devlive.sdk.openai.entity.ModelEntity;
67
import org.devlive.sdk.openai.entity.UserKeyEntity;
8+
import org.devlive.sdk.openai.model.ProviderModel;
9+
import org.devlive.sdk.openai.model.UrlModel;
710
import org.devlive.sdk.openai.response.CompleteChatResponse;
811
import org.devlive.sdk.openai.response.CompleteResponse;
912
import org.devlive.sdk.openai.response.ModelResponse;
1013
import org.devlive.sdk.openai.response.UserKeyResponse;
14+
import org.devlive.sdk.openai.utils.ProviderUtils;
1115

16+
@Slf4j
1217
public abstract class DefaultClient
1318
{
1419
protected DefaultApi api;
20+
protected ProviderModel provider;
1521

1622
public ModelResponse getModels()
1723
{
18-
return this.api.fetchModels()
24+
return this.api.fetchModels(ProviderUtils.getUrl(provider, UrlModel.FETCH_MODELS))
1925
.blockingGet();
2026
}
2127

@@ -27,13 +33,13 @@ public ModelEntity getModel(String model)
2733

2834
public CompleteResponse createCompletion(CompletionEntity configure)
2935
{
30-
return this.api.fetchCompletions(configure)
36+
return this.api.fetchCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
3137
.blockingGet();
3238
}
3339

3440
public CompleteChatResponse createChatCompletion(CompletionChatEntity configure)
3541
{
36-
return this.api.fetchChatCompletions(configure)
42+
return this.api.fetchChatCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_CHAT_COMPLETIONS), configure)
3743
.blockingGet();
3844
}
3945

src/main/java/org/devlive/sdk/openai/OpenAiClient.java

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
import com.fasterxml.jackson.annotation.JsonInclude;
44
import com.fasterxml.jackson.databind.ObjectMapper;
5-
import com.google.common.base.Preconditions;
65
import lombok.Builder;
76
import lombok.extern.slf4j.Slf4j;
87
import okhttp3.OkHttpClient;
98
import org.apache.commons.lang3.ObjectUtils;
109
import org.apache.commons.lang3.StringUtils;
10+
import org.devlive.sdk.openai.exception.ParamException;
11+
import org.devlive.sdk.openai.interceptor.AzureInterceptor;
1112
import org.devlive.sdk.openai.interceptor.DefaultInterceptor;
13+
import org.devlive.sdk.openai.interceptor.OpenAiInterceptor;
14+
import org.devlive.sdk.openai.model.ProviderModel;
1215
import retrofit2.Retrofit;
1316
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
1417
import retrofit2.converter.jackson.JacksonConverterFactory;
@@ -27,14 +30,31 @@ public class OpenAiClient
2730
private Integer timeout;
2831
private TimeUnit unit;
2932
private OkHttpClient client;
33+
private ProviderModel provider;
34+
// Azure provider requires
35+
private String model; // The model name deployed in azure
36+
private String version;
3037

3138
private OpenAiClient(OpenAiClientBuilder builder)
3239
{
3340
boolean hasApiKey = StringUtils.isNotEmpty(builder.apiKey);
3441
if (!hasApiKey) {
3542
log.error("Invalid OpenAi token");
43+
throw new ParamException("Invalid OpenAi token");
44+
}
45+
46+
if (ObjectUtils.isEmpty(builder.provider)) {
47+
builder.provider(ProviderModel.openai);
48+
}
49+
50+
if (builder.provider.equals(ProviderModel.azure)) {
51+
if (ObjectUtils.isEmpty(builder.model)) {
52+
throw new ParamException("Azure provider model not specified");
53+
}
54+
if (ObjectUtils.isEmpty(builder.version)) {
55+
throw new ParamException("Azure provider version not specified");
56+
}
3657
}
37-
Preconditions.checkState(hasApiKey, "Invalid OpenAi token");
3858

3959
if (ObjectUtils.isEmpty(builder.apiHost)) {
4060
builder.apiHost(null);
@@ -49,6 +69,7 @@ private OpenAiClient(OpenAiClientBuilder builder)
4969
builder.client(null);
5070
}
5171

72+
super.provider = builder.provider;
5273
// Build a remote API client
5374
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
5475
this.api = new Retrofit.Builder()
@@ -74,8 +95,10 @@ public OpenAiClientBuilder apiHost(String apiHost)
7495
apiHost = "https://api.openai.com";
7596
}
7697
else {
77-
Preconditions.checkState(apiHost.startsWith("http") || apiHost.startsWith("https"),
78-
"Api host must start with http or https");
98+
boolean flag = apiHost.startsWith("http") || apiHost.startsWith("https");
99+
if (!flag) {
100+
throw new ParamException(String.format("Invalid apiHost <%s> must start with http or https", apiHost));
101+
}
79102
}
80103
this.apiHost = apiHost;
81104
return this;
@@ -101,8 +124,12 @@ public OpenAiClientBuilder unit(TimeUnit unit)
101124

102125
public OpenAiClientBuilder client(OkHttpClient client)
103126
{
127+
if (ObjectUtils.isEmpty(this.provider)) {
128+
this.provider = ProviderModel.openai;
129+
}
130+
104131
if (ObjectUtils.isEmpty(client)) {
105-
log.warn("No client, creating default client");
132+
log.debug("No client specified, creating default client");
106133
client = new OkHttpClient.Builder()
107134
.connectTimeout(this.timeout, this.unit)
108135
.writeTimeout(this.timeout, this.unit)
@@ -111,8 +138,13 @@ public OpenAiClientBuilder client(OkHttpClient client)
111138
.build();
112139
}
113140
// Add default interceptor
114-
DefaultInterceptor interceptor = new DefaultInterceptor();
115-
interceptor.setApiKey(this.apiKey);
141+
DefaultInterceptor interceptor = new OpenAiInterceptor();
142+
if (this.provider.equals(ProviderModel.azure)) {
143+
interceptor = new AzureInterceptor();
144+
interceptor.setVersion(this.version);
145+
interceptor.setModel(this.model);
146+
}
147+
interceptor.setApiKey(apiKey);
116148
client = client.newBuilder()
117149
.addInterceptor(interceptor)
118150
.build();
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package org.devlive.sdk.openai.interceptor;
2+
3+
import lombok.extern.slf4j.Slf4j;
4+
import okhttp3.HttpUrl;
5+
import okhttp3.Request;
6+
import org.apache.commons.lang3.StringUtils;
7+
import org.devlive.sdk.openai.exception.ParamException;
8+
9+
import java.util.List;
10+
11+
@Slf4j
12+
public class AzureInterceptor
13+
extends DefaultInterceptor
14+
{
15+
public AzureInterceptor()
16+
{
17+
log.debug("Azure Interceptor");
18+
}
19+
20+
@Override
21+
protected Request prepared(Request original)
22+
{
23+
HttpUrl httpUrl = original.url();
24+
List<String> pathSegments = httpUrl.pathSegments();
25+
// Remove all path segments
26+
httpUrl = this.removePathSegment(httpUrl);
27+
// https://${your-resource-name}.openai.azure.com/openai/deployments/${deployment-id}/
28+
pathSegments.add(0, this.getModel());
29+
pathSegments.add(0, "deployments");
30+
pathSegments.add(0, "openai");
31+
httpUrl = httpUrl.newBuilder()
32+
.host(httpUrl.host())
33+
.port(httpUrl.port())
34+
.addPathSegments(String.join("/", pathSegments))
35+
.addQueryParameter("api-version", this.getVersion())
36+
.build();
37+
log.debug("Azure interceptor request url {}", httpUrl.toString());
38+
if (StringUtils.isEmpty(this.getApiKey())) {
39+
throw new ParamException("Invalid OpenAi token, must be non-empty");
40+
}
41+
return original.newBuilder()
42+
.header("api-key", this.getApiKey())
43+
.header("Content-Type", "application/json")
44+
.url(httpUrl)
45+
.method(original.method(), original.body())
46+
.build();
47+
}
48+
49+
private HttpUrl removePathSegment(HttpUrl httpUrl)
50+
{
51+
List<String> pathSegments = httpUrl.pathSegments();
52+
for (int i = 0; i < pathSegments.size(); i++) {
53+
httpUrl = httpUrl.newBuilder()
54+
.removePathSegment(0)
55+
.build();
56+
}
57+
return httpUrl;
58+
}
59+
}

src/main/java/org/devlive/sdk/openai/interceptor/DefaultInterceptor.java

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,32 @@
1111
import okhttp3.ResponseBody;
1212
import okio.Buffer;
1313
import org.apache.commons.lang3.ObjectUtils;
14-
import org.apache.commons.lang3.StringUtils;
1514
import org.devlive.sdk.openai.exception.AuthorizedException;
16-
import org.devlive.sdk.openai.exception.ParamException;
1715
import org.devlive.sdk.openai.exception.RequestException;
1816
import org.devlive.sdk.openai.response.DefaultResponse;
1917
import org.devlive.sdk.openai.utils.JsonUtils;
2018

2119
import java.io.IOException;
2220

2321
@Slf4j
24-
public class DefaultInterceptor
22+
@Setter
23+
@Getter
24+
public abstract class DefaultInterceptor
2525
implements Interceptor
2626
{
27-
@Getter
28-
@Setter
2927
private String apiKey;
28+
private String model;
29+
private String version;
3030

31-
public DefaultInterceptor()
32-
{
33-
log.warn("Default Interceptor");
34-
}
35-
36-
public Request headers(Request original)
37-
{
38-
if (StringUtils.isEmpty(this.apiKey)) {
39-
throw new ParamException("Invalid OpenAi token, must be non-empty");
40-
}
41-
return original.newBuilder()
42-
.header("Authorization", String.format("Bearer %s", this.apiKey))
43-
.header("Content-Type", "application/json")
44-
.method(original.method(), original.body())
45-
.build();
46-
}
31+
protected abstract Request prepared(Request original);
4732

4833
@Override
4934
public Response intercept(Chain chain) throws IOException
5035
{
5136
JsonUtils<DefaultResponse> jsonInstance = JsonUtils.getInstance();
5237

5338
Request original = chain.request();
54-
Request request = this.headers(original);
39+
Request request = this.prepared(original);
5540

5641
RequestBody requestBody = request.body();
5742
if (ObjectUtils.isNotEmpty(requestBody)) {
@@ -82,7 +67,8 @@ public Response intercept(Chain chain) throws IOException
8267
}
8368

8469
// Has error
85-
if (response.code() == 404 || response.code() == 400 || response.code() == 403) {
70+
if (response.code() == 404 || response.code() == 400 || response.code() == 403
71+
|| response.code() == 405) {
8672
ResponseBody body = response.body();
8773
if (ObjectUtils.isEmpty(body)) {
8874
throw new NullPointerException("Failed to intercept request because no body");
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package org.devlive.sdk.openai.interceptor;
2+
3+
import lombok.extern.slf4j.Slf4j;
4+
import okhttp3.Request;
5+
import org.apache.commons.lang3.StringUtils;
6+
import org.devlive.sdk.openai.exception.ParamException;
7+
8+
@Slf4j
9+
public class OpenAiInterceptor
10+
extends DefaultInterceptor
11+
{
12+
public OpenAiInterceptor()
13+
{
14+
log.debug("OpenAi Interceptor");
15+
}
16+
17+
public Request prepared(Request original)
18+
{
19+
log.debug("OpenAi interceptor request url {}", original.url());
20+
if (StringUtils.isEmpty(this.getApiKey())) {
21+
throw new ParamException("Invalid OpenAi token, must be non-empty");
22+
}
23+
return original.newBuilder()
24+
.header("Authorization", String.format("Bearer %s", this.getApiKey()))
25+
.header("Content-Type", "application/json")
26+
.method(original.method(), original.body())
27+
.build();
28+
}
29+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package org.devlive.sdk.openai.model;
2+
3+
public enum ProviderModel
4+
{
5+
openai,
6+
azure
7+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.devlive.sdk.openai.model;
2+
3+
public enum UrlModel
4+
{
5+
FETCH_MODELS,
6+
FETCH_MODEL,
7+
FETCH_COMPLETIONS,
8+
FETCH_CHAT_COMPLETIONS,
9+
FETCH_USER_API_KEYS,
10+
FETCH_CREATE_USER_API_KEY
11+
}

0 commit comments

Comments
 (0)