Skip to content

Commit 36ddcf7

Browse files
committed
[Google] [PaLM] Support chat prompt
1 parent dbca586 commit 36ddcf7

File tree

14 files changed

+292
-38
lines changed

14 files changed

+292
-38
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
---
2+
title: Chat
3+
---
4+
5+
!!! note
6+
7+
Support the google palm, product address: [https://developers.generativeai.google/products/palm](https://developers.generativeai.google/products/palm)
8+
9+
### Create chat
10+
11+
---
12+
13+
Creates a model response for the given chat conversation.
14+
15+
```java
16+
try(OpenAiClient client=OpenAiClient.builder()
17+
.provider(ProviderModel.GOOGLE_PALM)
18+
.model(CompletionModel.CHAT_BISON_001)
19+
.apiKey(System.getProperty("google.token"))
20+
.build())
21+
{
22+
List<MessageEntity> messages = Lists.newArrayList();
23+
messages.add(MessageEntity.builder()
24+
.content("Hello, my name is openai-java-sdk")
25+
.build());
26+
27+
PromptEntity prompt = PromptEntity.builder()
28+
.messages(messages)
29+
.build();
30+
31+
ChatEntity configure = ChatEntity.builder()
32+
.prompt(prompt)
33+
.build();
34+
35+
client.createPaLMChat(configure)
36+
.getCandidates()
37+
.forEach(System.out::println);
38+
}
39+
```
40+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
---
2+
title: Chat
3+
---
4+
5+
!!! note
6+
7+
支持 google palm,产品地址: [https://developers.generativeai.google/products/palm](https://developers.generativeai.google/products/palm)
8+
9+
### Create chat
10+
11+
---
12+
13+
为给定的聊天对话创建模型响应。
14+
15+
```java
16+
try(OpenAiClient client=OpenAiClient.builder()
17+
.provider(ProviderModel.GOOGLE_PALM)
18+
.model(CompletionModel.CHAT_BISON_001)
19+
.apiKey(System.getProperty("google.token"))
20+
.build())
21+
{
22+
List<MessageEntity> messages = Lists.newArrayList();
23+
messages.add(MessageEntity.builder()
24+
.content("Hello, my name is openai-java-sdk")
25+
.build());
26+
27+
PromptEntity prompt = PromptEntity.builder()
28+
.messages(messages)
29+
.build();
30+
31+
ChatEntity configure = ChatEntity.builder()
32+
.prompt(prompt)
33+
.build();
34+
35+
client.createPaLMChat(configure)
36+
.getCandidates()
37+
.forEach(System.out::println);
38+
}
39+
```

docs/mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,6 @@ nav:
9494
- reference/anthropic/completions.md
9595
- Google PaLM:
9696
- reference/google_palm/completions.md
97+
- reference/google_palm/chat.md
9798
- released.md
9899
- powered_by.md

src/main/java/org/devlive/sdk/openai/DefaultApi.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,28 @@ public interface DefaultApi
5757
Single<CompleteResponse> fetchCompletions(@Url String url,
5858
@Body CompletionEntity configure);
5959

60+
/**
61+
* Fetches the completions for PaLM.
62+
*
63+
* @param url the URL to fetch the completions from
64+
* @param configure the configuration entity for the completions
65+
* @return the complete response containing the fetched completions
66+
*/
6067
@POST
6168
Single<CompleteResponse> fetchPaLMCompletions(@Url String url,
6269
@Body org.devlive.sdk.openai.entity.google.CompletionEntity configure);
6370

71+
/**
72+
* Fetches the PaLM Chat data from the specified URL.
73+
*
74+
* @param url the URL to fetch the data from
75+
* @param configure the configuration of the chat entity
76+
* @return a Single object representing the complete response
77+
*/
78+
@POST
79+
Single<CompleteResponse> fetchPaLMChat(@Url String url,
80+
@Body org.devlive.sdk.openai.entity.google.ChatEntity configure);
81+
6482
/**
6583
* Creates a model response for the given chat conversation.
6684
*/

src/main/java/org/devlive/sdk/openai/DefaultClient.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.devlive.sdk.openai.entity.ModelEntity;
2121
import org.devlive.sdk.openai.entity.ModerationEntity;
2222
import org.devlive.sdk.openai.entity.UserKeyEntity;
23+
import org.devlive.sdk.openai.entity.google.MessageEntity;
2324
import org.devlive.sdk.openai.exception.RequestException;
2425
import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
2526
import org.devlive.sdk.openai.model.ProviderModel;
@@ -79,6 +80,17 @@ public CompleteResponse createPaLMCompletion(org.devlive.sdk.openai.entity.googl
7980
.blockingGet();
8081
}
8182

83+
public CompleteResponse createPaLMChat(org.devlive.sdk.openai.entity.google.ChatEntity configure)
84+
{
85+
MessageEntity message = MessageEntity.builder()
86+
.content("NEXT REQUEST")
87+
.build();
88+
configure.getPrompt().getMessages()
89+
.add(message);
90+
return this.api.fetchPaLMChat(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
91+
.blockingGet();
92+
}
93+
8294
public ChatResponse createChatCompletion(ChatEntity configure)
8395
{
8496
String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_CHAT_COMPLETIONS);
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.devlive.sdk.openai.entity.google;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Builder;
7+
import lombok.Data;
8+
import lombok.ToString;
9+
10+
@Data
11+
@Builder
12+
@ToString
13+
@AllArgsConstructor
14+
@JsonIgnoreProperties(ignoreUnknown = true)
15+
public class ChatEntity
16+
{
17+
@JsonProperty(value = "prompt")
18+
private PromptEntity prompt;
19+
20+
@JsonProperty(value = "temperature")
21+
@Builder.Default
22+
private Double temperature = 0.25;
23+
24+
@JsonProperty(value = "top_k")
25+
@Builder.Default
26+
private Integer topK = 40;
27+
28+
@JsonProperty(value = "top_p")
29+
@Builder.Default
30+
private Double topP = 1.0;
31+
32+
@JsonProperty(value = "candidate_count")
33+
@Builder.Default
34+
private Integer candidateCount = 1;
35+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.devlive.sdk.openai.entity.google;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Builder;
7+
import lombok.Data;
8+
import lombok.ToString;
9+
10+
@Data
11+
@Builder
12+
@ToString
13+
@AllArgsConstructor
14+
@JsonIgnoreProperties(ignoreUnknown = true)
15+
public class ExampleEntity
16+
{
17+
@JsonProperty(value = "input")
18+
private MessageEntity input;
19+
20+
@JsonProperty(value = "output")
21+
private MessageEntity output;
22+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.devlive.sdk.openai.entity.google;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Builder;
7+
import lombok.Data;
8+
import lombok.ToString;
9+
10+
@Data
11+
@Builder
12+
@ToString
13+
@AllArgsConstructor
14+
@JsonIgnoreProperties(ignoreUnknown = true)
15+
public class MessageEntity
16+
{
17+
@JsonProperty(value = "content")
18+
private String content;
19+
}

src/main/java/org/devlive/sdk/openai/entity/google/PromptEntity.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import lombok.Data;
88
import lombok.ToString;
99

10+
import java.util.List;
11+
1012
@Data
1113
@Builder
1214
@ToString
@@ -16,4 +18,13 @@ public class PromptEntity
1618
{
1719
@JsonProperty(value = "text")
1820
private String text;
21+
22+
@JsonProperty(value = "context")
23+
private String context;
24+
25+
@JsonProperty(value = "examples")
26+
private List<ExampleEntity> examples;
27+
28+
@JsonProperty(value = "messages")
29+
private List<MessageEntity> messages;
1930
}

src/main/java/org/devlive/sdk/openai/interceptor/GooglePaLMInterceptor.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import okhttp3.Request;
77
import org.apache.commons.lang3.StringUtils;
88
import org.devlive.sdk.openai.exception.ParamException;
9+
import org.devlive.sdk.openai.model.CompletionModel;
10+
import org.devlive.sdk.openai.utils.EnumsUtils;
911
import org.devlive.sdk.openai.utils.HttpUrlUtils;
1012

1113
import java.util.List;
@@ -29,7 +31,7 @@ protected Request prepared(Request original)
2931
List<String> pathSegments = Lists.newArrayList();
3032
httpUrl = HttpUrlUtils.removePathSegment(httpUrl);
3133
// https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=YOUR_KEY
32-
pathSegments.add(0, String.join(":", this.getModel(), "generateText"));
34+
pathSegments.add(0, String.join(":", this.getModel(), this.getModelType()));
3335
pathSegments.add(0, "models");
3436
pathSegments.add(0, "v1beta2");
3537
httpUrl = httpUrl.newBuilder()
@@ -45,4 +47,22 @@ protected Request prepared(Request original)
4547
.method(original.method(), original.body())
4648
.build();
4749
}
50+
51+
/**
52+
* Retrieves the model type based on the current model value.
53+
*
54+
* @return the model type as a string
55+
*/
56+
private String getModelType()
57+
{
58+
CompletionModel model = EnumsUtils.getCompleteModel(this.getModel());
59+
switch (model) {
60+
case TEXT_BISON_001:
61+
return "generateText";
62+
case CHAT_BISON_001:
63+
return "generateMessage";
64+
default:
65+
throw new ParamException("Unsupported Google PaLM model");
66+
}
67+
}
4868
}

0 commit comments

Comments
 (0)