Skip to content

Commit f480767

Browse files
authored
Xinference: support ChatRequestParameters (#233)
1 parent dcb1615 commit f480767

File tree

2 files changed

+68
-48
lines changed

2 files changed

+68
-48
lines changed

models/langchain4j-community-xinference/src/main/java/dev/langchain4j/community/model/xinference/XinferenceChatModel.java

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static dev.langchain4j.community.model.xinference.InternalXinferenceHelper.toXinferenceMessages;
77
import static dev.langchain4j.community.model.xinference.InternalXinferenceHelper.tokenUsageFrom;
88
import static dev.langchain4j.internal.RetryUtils.withRetry;
9+
import static dev.langchain4j.internal.Utils.copy;
910
import static dev.langchain4j.internal.Utils.getOrDefault;
1011
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
1112
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
@@ -39,19 +40,16 @@ public class XinferenceChatModel implements ChatModel {
3940
private static final Logger log = LoggerFactory.getLogger(XinferenceChatModel.class);
4041

4142
private final XinferenceClient client;
42-
private final String modelName;
43-
private final Double temperature;
44-
private final Double topP;
45-
private final List<String> stop;
46-
private final Integer maxTokens;
47-
private final Double presencePenalty;
48-
private final Double frequencyPenalty;
43+
private final Integer maxRetries;
44+
private final List<ChatModelListener> listeners;
45+
private final ChatRequestParameters defaultRequestParameters;
46+
47+
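/* TODO: support custom ChatRequestParameters; until then the fields below (seed, user, toolChoice, parallelToolCalls) stay on the model and are not part of defaultRequestParameters */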
48+
4949
private final Integer seed;
5050
private final String user;
5151
private final Object toolChoice;
5252
private final Boolean parallelToolCalls;
53-
private final Integer maxRetries;
54-
private final List<ChatModelListener> listeners;
5553

5654
public XinferenceChatModel(
5755
String baseUrl,
@@ -75,6 +73,8 @@ public XinferenceChatModel(
7573
Map<String, String> customHeaders,
7674
List<ChatModelListener> listeners) {
7775
timeout = getOrDefault(timeout, Duration.ofSeconds(60));
76+
this.maxRetries = getOrDefault(maxRetries, 3);
77+
this.listeners = copy(listeners);
7878

7979
this.client = XinferenceClient.builder()
8080
.baseUrl(baseUrl)
@@ -88,20 +88,30 @@ public XinferenceChatModel(
8888
.logResponses(logResponses)
8989
.customHeaders(customHeaders)
9090
.build();
91+
this.defaultRequestParameters = ChatRequestParameters.builder()
92+
.modelName(ensureNotBlank(modelName, "modelName"))
93+
.temperature(temperature)
94+
.topP(topP)
95+
.stopSequences(stop)
96+
.maxOutputTokens(maxTokens)
97+
.presencePenalty(presencePenalty)
98+
.frequencyPenalty(frequencyPenalty)
99+
.build();
91100

92-
this.modelName = ensureNotBlank(modelName, "modelName");
93-
this.temperature = temperature;
94-
this.topP = topP;
95-
this.stop = stop;
96-
this.maxTokens = maxTokens;
97-
this.presencePenalty = presencePenalty;
98-
this.frequencyPenalty = frequencyPenalty;
99101
this.seed = seed;
100102
this.user = user;
101103
this.toolChoice = toolChoice;
102104
this.parallelToolCalls = parallelToolCalls;
103-
this.maxRetries = getOrDefault(maxRetries, 3);
104-
this.listeners = getOrDefault(listeners, List.of());
105+
}
106+
107+
@Override
108+
public ChatRequestParameters defaultRequestParameters() {
109+
return defaultRequestParameters;
110+
}
111+
112+
@Override
113+
public List<ChatModelListener> listeners() {
114+
return listeners;
105115
}
106116

107117
@Override
@@ -110,14 +120,14 @@ public ChatResponse doChat(ChatRequest request) {
110120
ChatRequestParameters parameters = request.parameters();
111121
List<ToolSpecification> toolSpecifications = parameters.toolSpecifications();
112122
ChatCompletionRequest.Builder builder = ChatCompletionRequest.builder()
113-
.model(modelName)
123+
.model(parameters.modelName())
114124
.messages(toXinferenceMessages(messages))
115-
.temperature(temperature)
116-
.topP(topP)
117-
.stop(stop)
118-
.maxTokens(maxTokens)
119-
.presencePenalty(presencePenalty)
120-
.frequencyPenalty(frequencyPenalty)
125+
.temperature(parameters.temperature())
126+
.topP(parameters.topP())
127+
.stop(parameters.stopSequences())
128+
.maxTokens(parameters.maxOutputTokens())
129+
.presencePenalty(parameters.presencePenalty())
130+
.frequencyPenalty(parameters.frequencyPenalty())
121131
.user(user)
122132
.seed(seed)
123133
.toolChoice(toolChoice)

models/langchain4j-community-xinference/src/main/java/dev/langchain4j/community/model/xinference/XinferenceStreamingChatModel.java

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.langchain4j.community.model.xinference.InternalXinferenceHelper.toTools;
44
import static dev.langchain4j.community.model.xinference.InternalXinferenceHelper.toXinferenceMessages;
5+
import static dev.langchain4j.internal.Utils.copy;
56
import static dev.langchain4j.internal.Utils.getOrDefault;
67
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
78
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
@@ -38,18 +39,15 @@ public class XinferenceStreamingChatModel implements StreamingChatModel {
3839
private static final Logger log = LoggerFactory.getLogger(XinferenceStreamingChatModel.class);
3940

4041
private final XinferenceClient client;
41-
private final String modelName;
42-
private final Double temperature;
43-
private final Double topP;
44-
private final List<String> stop;
45-
private final Integer maxTokens;
46-
private final Double presencePenalty;
47-
private final Double frequencyPenalty;
42+
private final List<ChatModelListener> listeners;
43+
private final ChatRequestParameters defaultRequestParameters;
44+
45+
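/* TODO: support custom ChatRequestParameters; until then the fields below (seed, user, toolChoice, parallelToolCalls) stay on the model and are not part of defaultRequestParameters */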
46+
4847
private final Integer seed;
4948
private final String user;
5049
private final Object toolChoice;
5150
private final Boolean parallelToolCalls;
52-
private final List<ChatModelListener> listeners;
5351

5452
public XinferenceStreamingChatModel(
5553
String baseUrl,
@@ -72,6 +70,7 @@ public XinferenceStreamingChatModel(
7270
Map<String, String> customHeaders,
7371
List<ChatModelListener> listeners) {
7472
timeout = getOrDefault(timeout, Duration.ofSeconds(60));
73+
this.listeners = copy(listeners);
7574

7675
this.client = XinferenceClient.builder()
7776
.baseUrl(baseUrl)
@@ -85,19 +84,30 @@ public XinferenceStreamingChatModel(
8584
.logStreamingResponses(logResponses)
8685
.customHeaders(customHeaders)
8786
.build();
87+
this.defaultRequestParameters = ChatRequestParameters.builder()
88+
.modelName(ensureNotBlank(modelName, "modelName"))
89+
.temperature(temperature)
90+
.topP(topP)
91+
.stopSequences(stop)
92+
.maxOutputTokens(maxTokens)
93+
.presencePenalty(presencePenalty)
94+
.frequencyPenalty(frequencyPenalty)
95+
.build();
8896

89-
this.modelName = ensureNotBlank(modelName, "modelName");
90-
this.temperature = temperature;
91-
this.topP = topP;
92-
this.stop = stop;
93-
this.maxTokens = maxTokens;
94-
this.presencePenalty = presencePenalty;
95-
this.frequencyPenalty = frequencyPenalty;
9697
this.seed = seed;
9798
this.user = user;
9899
this.toolChoice = toolChoice;
99100
this.parallelToolCalls = parallelToolCalls;
100-
this.listeners = getOrDefault(listeners, List.of());
101+
}
102+
103+
@Override
104+
public ChatRequestParameters defaultRequestParameters() {
105+
return defaultRequestParameters;
106+
}
107+
108+
@Override
109+
public List<ChatModelListener> listeners() {
110+
return listeners;
101111
}
102112

103113
@Override
@@ -107,14 +117,14 @@ public void doChat(ChatRequest request, StreamingChatResponseHandler handler) {
107117
List<ToolSpecification> toolSpecifications = parameters.toolSpecifications();
108118
ChatCompletionRequest.Builder builder = ChatCompletionRequest.builder().stream(true)
109119
.streamOptions(StreamOptions.of(true))
110-
.model(modelName)
120+
.model(parameters.modelName())
111121
.messages(toXinferenceMessages(messages))
112-
.temperature(temperature)
113-
.topP(topP)
114-
.stop(stop)
115-
.maxTokens(maxTokens)
116-
.presencePenalty(presencePenalty)
117-
.frequencyPenalty(frequencyPenalty)
122+
.temperature(parameters.temperature())
123+
.topP(parameters.topP())
124+
.stop(parameters.stopSequences())
125+
.maxTokens(parameters.maxOutputTokens())
126+
.presencePenalty(parameters.presencePenalty())
127+
.frequencyPenalty(parameters.frequencyPenalty())
118128
.user(user)
119129
.seed(seed)
120130
.toolChoice(toolChoice)

0 commit comments

Comments
 (0)