Skip to content

Commit 431a88c

Browse files
authored
ChatGLM: support ChatRequestParameters and mark as deprecated (#232)
1 parent 572be33 commit 431a88c

File tree

1 file changed

+39
-18
lines changed
  • models/langchain4j-community-chatglm/src/main/java/dev/langchain4j/community/model/chatglm

1 file changed

+39
-18
lines changed

models/langchain4j-community-chatglm/src/main/java/dev/langchain4j/community/model/chatglm/ChatGlmChatModel.java

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package dev.langchain4j.community.model.chatglm;
22

3+
import static dev.langchain4j.data.message.AiMessage.from;
34
import static dev.langchain4j.internal.RetryUtils.withRetry;
5+
import static dev.langchain4j.internal.Utils.copy;
46
import static dev.langchain4j.internal.Utils.getOrDefault;
57
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
68
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
@@ -13,7 +15,9 @@
1315
import dev.langchain4j.data.message.SystemMessage;
1416
import dev.langchain4j.data.message.UserMessage;
1517
import dev.langchain4j.model.chat.ChatModel;
18+
import dev.langchain4j.model.chat.listener.ChatModelListener;
1619
import dev.langchain4j.model.chat.request.ChatRequest;
20+
import dev.langchain4j.model.chat.request.ChatRequestParameters;
1721
import dev.langchain4j.model.chat.response.ChatResponse;
1822
import java.time.Duration;
1923
import java.util.ArrayList;
@@ -23,14 +27,16 @@
2327
/**
2428
* Support <a href="https://github.com/THUDM/ChatGLM-6B">ChatGLM</a>,
2529
* ChatGLM2 and ChatGLM3 APIs are compatible with the OpenAI API
30+
*
31+
* @deprecated Please use langchain4j-community-zhipu-ai for more advanced features instead.
2632
*/
33+
@Deprecated(forRemoval = true)
2734
public class ChatGlmChatModel implements ChatModel {
2835

2936
private final ChatGlmClient client;
30-
private final Double temperature;
31-
private final Double topP;
32-
private final Integer maxLength;
37+
private final List<ChatModelListener> listeners;
3338
private final Integer maxRetries;
39+
private final ChatRequestParameters defaultRequestParameters;
3440

3541
public ChatGlmChatModel(
3642
String baseUrl,
@@ -40,14 +46,17 @@ public ChatGlmChatModel(
4046
Double topP,
4147
Integer maxLength,
4248
boolean logRequests,
43-
boolean logResponses) {
49+
boolean logResponses,
50+
List<ChatModelListener> listeners) {
4451
baseUrl = ensureNotNull(baseUrl, "baseUrl");
4552
timeout = getOrDefault(timeout, ofSeconds(60));
46-
this.temperature = temperature;
4753
this.maxRetries = getOrDefault(maxRetries, 3);
48-
this.topP = topP;
49-
this.maxLength = maxLength;
50-
54+
this.listeners = copy(listeners);
55+
this.defaultRequestParameters = ChatRequestParameters.builder()
56+
.temperature(temperature)
57+
.topP(topP)
58+
.maxOutputTokens(maxLength)
59+
.build();
5160
this.client = ChatGlmClient.builder()
5261
.baseUrl(baseUrl)
5362
.timeout(timeout)
@@ -57,13 +66,19 @@ public ChatGlmChatModel(
5766
}
5867

5968
@Override
60-
public ChatResponse doChat(ChatRequest chatRequest) {
61-
List<ChatMessage> chatMessages = chatRequest.messages();
62-
return ChatResponse.builder().aiMessage(doChat(chatMessages)).build();
69+
public ChatRequestParameters defaultRequestParameters() {
70+
return defaultRequestParameters;
71+
}
72+
73+
@Override
74+
public List<ChatModelListener> listeners() {
75+
return listeners;
6376
}
6477

65-
public AiMessage doChat(List<ChatMessage> messages) {
66-
// get last user message
78+
@Override
79+
public ChatResponse doChat(ChatRequest chatRequest) {
80+
List<ChatMessage> messages = chatRequest.messages();
81+
ChatRequestParameters parameters = chatRequest.parameters();
6782
String prompt;
6883
ChatMessage lastMessage = messages.get(messages.size() - 1);
6984
if (lastMessage instanceof UserMessage userMessage) {
@@ -74,15 +89,15 @@ public AiMessage doChat(List<ChatMessage> messages) {
7489
List<List<String>> history = toHistory(messages.subList(0, messages.size() - 1));
7590
ChatCompletionRequest request = ChatCompletionRequest.builder()
7691
.prompt(prompt)
77-
.temperature(temperature)
78-
.topP(topP)
79-
.maxLength(maxLength)
92+
.temperature(parameters.temperature())
93+
.topP(parameters.topP())
94+
.maxLength(parameters.maxOutputTokens())
8095
.history(history)
8196
.build();
8297

8398
ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries);
8499

85-
return AiMessage.from(response.getResponse());
100+
return ChatResponse.builder().aiMessage(from(response.getResponse())).build();
86101
}
87102

88103
private List<List<String>> toHistory(List<ChatMessage> historyMessages) {
@@ -138,6 +153,7 @@ public static class ChatGlmChatModelBuilder {
138153
private Integer maxLength;
139154
private boolean logRequests;
140155
private boolean logResponses;
156+
private List<ChatModelListener> listeners;
141157

142158
public ChatGlmChatModelBuilder() {
143159
// This is public so it can be extended
@@ -184,9 +200,14 @@ public ChatGlmChatModelBuilder logResponses(boolean logResponses) {
184200
return this;
185201
}
186202

203+
public ChatGlmChatModelBuilder listeners(List<ChatModelListener> listeners) {
204+
this.listeners = listeners;
205+
return this;
206+
}
207+
187208
public ChatGlmChatModel build() {
188209
return new ChatGlmChatModel(
189-
baseUrl, timeout, temperature, maxRetries, topP, maxLength, logRequests, logResponses);
210+
baseUrl, timeout, temperature, maxRetries, topP, maxLength, logRequests, logResponses, listeners);
190211
}
191212
}
192213
}

0 commit comments

Comments
 (0)