66import static dev .langchain4j .community .model .xinference .InternalXinferenceHelper .toXinferenceMessages ;
77import static dev .langchain4j .community .model .xinference .InternalXinferenceHelper .tokenUsageFrom ;
88import static dev .langchain4j .internal .RetryUtils .withRetry ;
9+ import static dev .langchain4j .internal .Utils .copy ;
910import static dev .langchain4j .internal .Utils .getOrDefault ;
1011import static dev .langchain4j .internal .ValidationUtils .ensureNotBlank ;
1112import static dev .langchain4j .spi .ServiceHelper .loadFactories ;
@@ -39,19 +40,16 @@ public class XinferenceChatModel implements ChatModel {
3940 private static final Logger log = LoggerFactory .getLogger (XinferenceChatModel .class );
4041
4142 private final XinferenceClient client ;
42- private final String modelName ;
43- private final Double temperature ;
44- private final Double topP ;
45- private final List <String > stop ;
46- private final Integer maxTokens ;
47- private final Double presencePenalty ;
48- private final Double frequencyPenalty ;
43+ private final Integer maxRetries ;
44+ private final List <ChatModelListener > listeners ;
45+ private final ChatRequestParameters defaultRequestParameters ;
46+
47+ /* TODO: support custom ChatRequestParameters */
48+
4949 private final Integer seed ;
5050 private final String user ;
5151 private final Object toolChoice ;
5252 private final Boolean parallelToolCalls ;
53- private final Integer maxRetries ;
54- private final List <ChatModelListener > listeners ;
5553
5654 public XinferenceChatModel (
5755 String baseUrl ,
@@ -75,6 +73,8 @@ public XinferenceChatModel(
7573 Map <String , String > customHeaders ,
7674 List <ChatModelListener > listeners ) {
7775 timeout = getOrDefault (timeout , Duration .ofSeconds (60 ));
76+ this .maxRetries = getOrDefault (maxRetries , 3 );
77+ this .listeners = copy (listeners );
7878
7979 this .client = XinferenceClient .builder ()
8080 .baseUrl (baseUrl )
@@ -88,20 +88,30 @@ public XinferenceChatModel(
8888 .logResponses (logResponses )
8989 .customHeaders (customHeaders )
9090 .build ();
91+ this .defaultRequestParameters = ChatRequestParameters .builder ()
92+ .modelName (ensureNotBlank (modelName , "modelName" ))
93+ .temperature (temperature )
94+ .topP (topP )
95+ .stopSequences (stop )
96+ .maxOutputTokens (maxTokens )
97+ .presencePenalty (presencePenalty )
98+ .frequencyPenalty (frequencyPenalty )
99+ .build ();
91100
92- this .modelName = ensureNotBlank (modelName , "modelName" );
93- this .temperature = temperature ;
94- this .topP = topP ;
95- this .stop = stop ;
96- this .maxTokens = maxTokens ;
97- this .presencePenalty = presencePenalty ;
98- this .frequencyPenalty = frequencyPenalty ;
99101 this .seed = seed ;
100102 this .user = user ;
101103 this .toolChoice = toolChoice ;
102104 this .parallelToolCalls = parallelToolCalls ;
103- this .maxRetries = getOrDefault (maxRetries , 3 );
104- this .listeners = getOrDefault (listeners , List .of ());
105+ }
106+
107+ @ Override
108+ public ChatRequestParameters defaultRequestParameters () {
109+ return defaultRequestParameters ;
110+ }
111+
112+ @ Override
113+ public List <ChatModelListener > listeners () {
114+ return listeners ;
105115 }
106116
107117 @ Override
@@ -110,14 +120,14 @@ public ChatResponse doChat(ChatRequest request) {
110120 ChatRequestParameters parameters = request .parameters ();
111121 List <ToolSpecification > toolSpecifications = parameters .toolSpecifications ();
112122 ChatCompletionRequest .Builder builder = ChatCompletionRequest .builder ()
113- .model (modelName )
123+ .model (parameters . modelName () )
114124 .messages (toXinferenceMessages (messages ))
115- .temperature (temperature )
116- .topP (topP )
117- .stop (stop )
118- .maxTokens (maxTokens )
119- .presencePenalty (presencePenalty )
120- .frequencyPenalty (frequencyPenalty )
125+ .temperature (parameters . temperature () )
126+ .topP (parameters . topP () )
127+ .stop (parameters . stopSequences () )
128+ .maxTokens (parameters . maxOutputTokens () )
129+ .presencePenalty (parameters . presencePenalty () )
130+ .frequencyPenalty (parameters . frequencyPenalty () )
121131 .user (user )
122132 .seed (seed )
123133 .toolChoice (toolChoice )
0 commit comments