Skip to content

Commit 40b57ed

Browse files
markpollackleijendary
authored andcommitted
Simplify builder pattern for options
This change streamlines the builder implementation by removing generics that was complicating the implementation and providing hard to debug checkstyle warnings. It adopts a simpler, more direct builder pattern. Key changes: - Remove generic type parameters from builder interfaces - Switch to concrete builder implementations with direct field access - Make all collection getters return unmodifiable views - Ensure proper copy semantics in builders and options - Add comprehensive test coverage for builder behavior Signed-off-by: leijendary <[email protected]>
1 parent c315868 commit 40b57ed

File tree

12 files changed

+865
-267
lines changed

12 files changed

+865
-267
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.vertexai.gemini;
1818

1919
import java.util.ArrayList;
20+
import java.util.HashMap;
2021
import java.util.HashSet;
2122
import java.util.List;
2223
import java.util.Map;
@@ -28,6 +29,7 @@
2829
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2930
import com.fasterxml.jackson.annotation.JsonProperty;
3031

32+
import org.springframework.ai.chat.prompt.ChatOptions;
3133
import org.springframework.ai.model.function.FunctionCallback;
3234
import org.springframework.ai.model.function.FunctionCallingOptions;
3335
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
@@ -68,7 +70,7 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
6870
/**
6971
* Optional. If specified, top k sampling will be used.
7072
*/
71-
private @JsonProperty("topK") Float topK;
73+
private @JsonProperty("topK") Integer topK;
7274

7375
/**
7476
* Optional. The maximum number of tokens to generate.
@@ -183,16 +185,11 @@ public void setTopP(Double topP) {
183185

184186
@Override
185187
public Integer getTopK() {
186-
return (this.topK != null) ? this.topK.intValue() : null;
188+
return this.topK;
187189
}
188190

189-
public void setTopK(Float topK) {
190-
this.topK = topK;
191-
}
192-
193-
@JsonIgnore
194191
public void setTopK(Integer topK) {
195-
this.topK = (topK != null) ? topK.floatValue() : null;
192+
this.topK = topK;
196193
}
197194

198195
public Integer getCandidateCount() {
@@ -346,6 +343,67 @@ public VertexAiGeminiChatOptions copy() {
346343
return fromOptions(this);
347344
}
348345

346+
public FunctionCallingOptions merge(ChatOptions options) {
347+
VertexAiGeminiChatOptions.Builder builder = VertexAiGeminiChatOptions.builder();
348+
349+
// Merge chat-specific options
350+
builder.model(options.getModel() != null ? options.getModel() : this.getModel())
351+
.maxOutputTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxOutputTokens())
352+
.stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences())
353+
.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature())
354+
.topP(options.getTopP() != null ? options.getTopP() : this.getTopP())
355+
.topK(options.getTopK() != null ? options.getTopK() : this.getTopK());
356+
357+
// Try to get function-specific properties if options is a FunctionCallingOptions
358+
if (options instanceof FunctionCallingOptions functionOptions) {
359+
builder.proxyToolCalls(functionOptions.getProxyToolCalls() != null ? functionOptions.getProxyToolCalls()
360+
: this.proxyToolCalls);
361+
362+
Set<String> functions = new HashSet<>();
363+
if (this.functions != null) {
364+
functions.addAll(this.functions);
365+
}
366+
if (functionOptions.getFunctions() != null) {
367+
functions.addAll(functionOptions.getFunctions());
368+
}
369+
builder.functions(functions);
370+
371+
List<FunctionCallback> functionCallbacks = new ArrayList<>();
372+
if (this.functionCallbacks != null) {
373+
functionCallbacks.addAll(this.functionCallbacks);
374+
}
375+
if (functionOptions.getFunctionCallbacks() != null) {
376+
functionCallbacks.addAll(functionOptions.getFunctionCallbacks());
377+
}
378+
builder.functionCallbacks(functionCallbacks);
379+
380+
Map<String, Object> context = new HashMap<>();
381+
if (this.toolContext != null) {
382+
context.putAll(this.toolContext);
383+
}
384+
if (functionOptions.getToolContext() != null) {
385+
context.putAll(functionOptions.getToolContext());
386+
}
387+
builder.toolContext(context);
388+
}
389+
else {
390+
// If not a FunctionCallingOptions, preserve current function-specific
391+
// properties
392+
builder.proxyToolCalls(this.proxyToolCalls);
393+
builder.functions(this.functions != null ? new HashSet<>(this.functions) : null);
394+
builder.functionCallbacks(this.functionCallbacks != null ? new ArrayList<>(this.functionCallbacks) : null);
395+
builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null);
396+
}
397+
398+
// Preserve Vertex AI Gemini-specific properties
399+
builder.candidateCount(this.candidateCount)
400+
.responseMimeType(this.responseMimeType)
401+
.googleSearchRetrieval(this.googleSearchRetrieval)
402+
.safetySettings(this.safetySettings != null ? new ArrayList<>(this.safetySettings) : null);
403+
404+
return builder.build();
405+
}
406+
349407
public enum TransportType {
350408

351409
GRPC, REST
@@ -371,7 +429,7 @@ public Builder topP(Double topP) {
371429
return this;
372430
}
373431

374-
public Builder topK(Float topK) {
432+
public Builder topK(Integer topK) {
375433
this.options.setTopK(topK);
376434
return this;
377435
}
@@ -473,10 +531,10 @@ public Builder withTopP(Double topP) {
473531
}
474532

475533
/**
476-
* @deprecated use {@link #topK(Float)} instead.
534+
* @deprecated use {@link #topK(Integer)} instead.
477535
*/
478536
@Deprecated(forRemoval = true, since = "1.0.0-M5")
479-
public Builder withTopK(Float topK) {
537+
public Builder withTopK(Integer topK) {
480538
this.options.setTopK(topK);
481539
return this;
482540
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ public void createRequestWithGenerationConfigOptions() {
203203
.model("DEFAULT_MODEL")
204204
.temperature(66.6)
205205
.maxOutputTokens(100)
206-
.topK(10.0f)
206+
.topK(10)
207207
.topP(5.0)
208208
.stopSequences(List.of("stop1", "stop2"))
209209
.candidateCount(1)
@@ -218,7 +218,7 @@ public void createRequestWithGenerationConfigOptions() {
218218
assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL");
219219
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f);
220220
assertThat(request.model().getGenerationConfig().getMaxOutputTokens()).isEqualTo(100);
221-
assertThat(request.model().getGenerationConfig().getTopK()).isEqualTo(10.0f);
221+
assertThat(request.model().getGenerationConfig().getTopK()).isEqualTo(10);
222222
assertThat(request.model().getGenerationConfig().getTopP()).isEqualTo(5.0f);
223223
assertThat(request.model().getGenerationConfig().getCandidateCount()).isEqualTo(1);
224224
assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1");

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.zhipuai;
1818

1919
import java.util.ArrayList;
20+
import java.util.HashMap;
2021
import java.util.HashSet;
2122
import java.util.List;
2223
import java.util.Map;
@@ -27,6 +28,7 @@
2728
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2829
import com.fasterxml.jackson.annotation.JsonProperty;
2930

31+
import org.springframework.ai.chat.prompt.ChatOptions;
3032
import org.springframework.ai.model.function.FunctionCallback;
3133
import org.springframework.ai.model.function.FunctionCallingOptions;
3234
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
@@ -438,6 +440,67 @@ public ZhiPuAiChatOptions copy() {
438440
return fromOptions(this);
439441
}
440442

443+
public FunctionCallingOptions merge(ChatOptions options) {
444+
ZhiPuAiChatOptions.Builder builder = ZhiPuAiChatOptions.builder();
445+
446+
// Merge chat-specific options
447+
builder.model(options.getModel() != null ? options.getModel() : this.getModel())
448+
.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens())
449+
.stop(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences())
450+
.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature())
451+
.topP(options.getTopP() != null ? options.getTopP() : this.getTopP());
452+
453+
// Try to get function-specific properties if options is a FunctionCallingOptions
454+
if (options instanceof FunctionCallingOptions functionOptions) {
455+
builder.proxyToolCalls(functionOptions.getProxyToolCalls() != null ? functionOptions.getProxyToolCalls()
456+
: this.proxyToolCalls);
457+
458+
Set<String> functions = new HashSet<>();
459+
if (this.functions != null) {
460+
functions.addAll(this.functions);
461+
}
462+
if (functionOptions.getFunctions() != null) {
463+
functions.addAll(functionOptions.getFunctions());
464+
}
465+
builder.functions(functions);
466+
467+
List<FunctionCallback> functionCallbacks = new ArrayList<>();
468+
if (this.functionCallbacks != null) {
469+
functionCallbacks.addAll(this.functionCallbacks);
470+
}
471+
if (functionOptions.getFunctionCallbacks() != null) {
472+
functionCallbacks.addAll(functionOptions.getFunctionCallbacks());
473+
}
474+
builder.functionCallbacks(functionCallbacks);
475+
476+
Map<String, Object> context = new HashMap<>();
477+
if (this.toolContext != null) {
478+
context.putAll(this.toolContext);
479+
}
480+
if (functionOptions.getToolContext() != null) {
481+
context.putAll(functionOptions.getToolContext());
482+
}
483+
builder.toolContext(context);
484+
}
485+
else {
486+
// If not a FunctionCallingOptions, preserve current function-specific
487+
// properties
488+
builder.proxyToolCalls(this.proxyToolCalls);
489+
builder.functions(this.functions != null ? new HashSet<>(this.functions) : null);
490+
builder.functionCallbacks(this.functionCallbacks != null ? new ArrayList<>(this.functionCallbacks) : null);
491+
builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null);
492+
}
493+
494+
// Preserve ZhiPuAi-specific properties
495+
builder.tools(this.tools)
496+
.toolChoice(this.toolChoice)
497+
.user(this.user)
498+
.requestId(this.requestId)
499+
.doSample(this.doSample);
500+
501+
return builder.build();
502+
}
503+
441504
public static class Builder {
442505

443506
protected ZhiPuAiChatOptions options;

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,77 +87,77 @@ public interface ChatOptions extends ModelOptions {
8787
* Returns a copy of this {@link ChatOptions}.
8888
* @return a copy of this {@link ChatOptions}
8989
*/
90-
ChatOptions copy();
90+
<T extends ChatOptions> T copy();
9191

9292
/**
9393
* Creates a new {@link ChatOptions.Builder} to create the default
9494
* {@link ChatOptions}.
9595
* @return Returns a new {@link ChatOptions.Builder}.
9696
*/
97-
static ChatOptions.Builder<? extends DefaultChatOptionsBuilder> builder() {
97+
static ChatOptions.Builder builder() {
9898
return new DefaultChatOptionsBuilder();
9999
}
100100

101101
/**
102102
* Builder for creating {@link ChatOptions} instance.
103103
*/
104-
interface Builder<B extends Builder<B>> {
104+
interface Builder {
105105

106106
/**
107107
* Builds with the model to use for the chat.
108108
* @param model
109109
* @return the builder
110110
*/
111-
B model(String model);
111+
Builder model(String model);
112112

113113
/**
114114
* Builds with the frequency penalty to use for the chat.
115115
* @param frequencyPenalty
116116
* @return the builder.
117117
*/
118-
B frequencyPenalty(Double frequencyPenalty);
118+
Builder frequencyPenalty(Double frequencyPenalty);
119119

120120
/**
121121
* Builds with the maximum number of tokens to use for the chat.
122122
* @param maxTokens
123123
* @return the builder.
124124
*/
125-
B maxTokens(Integer maxTokens);
125+
Builder maxTokens(Integer maxTokens);
126126

127127
/**
128128
* Builds with the presence penalty to use for the chat.
129129
* @param presencePenalty
130130
* @return the builder.
131131
*/
132-
B presencePenalty(Double presencePenalty);
132+
Builder presencePenalty(Double presencePenalty);
133133

134134
/**
135135
* Builds with the stop sequences to use for the chat.
136136
* @param stopSequences
137137
* @return the builder.
138138
*/
139-
B stopSequences(List<String> stopSequences);
139+
Builder stopSequences(List<String> stopSequences);
140140

141141
/**
142142
* Builds with the temperature to use for the chat.
143143
* @param temperature
144144
* @return the builder.
145145
*/
146-
B temperature(Double temperature);
146+
Builder temperature(Double temperature);
147147

148148
/**
149149
* Builds with the top K to use for the chat.
150150
* @param topK
151151
* @return the builder.
152152
*/
153-
B topK(Integer topK);
153+
Builder topK(Integer topK);
154154

155155
/**
156156
* Builds with the top P to use for the chat.
157157
* @param topP
158158
* @return the builder.
159159
*/
160-
B topP(Double topP);
160+
Builder topP(Double topP);
161161

162162
/**
163163
* Build the {@link ChatOptions}.

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.chat.prompt;
1818

19+
import java.util.ArrayList;
20+
import java.util.Collections;
1921
import java.util.List;
2022

2123
/**
@@ -77,7 +79,7 @@ public void setPresencePenalty(Double presencePenalty) {
7779

7880
@Override
7981
public List<String> getStopSequences() {
80-
return this.stopSequences;
82+
return this.stopSequences != null ? Collections.unmodifiableList(this.stopSequences) : null;
8183
}
8284

8385
public void setStopSequences(List<String> stopSequences) {
@@ -112,17 +114,18 @@ public void setTopP(Double topP) {
112114
}
113115

114116
@Override
115-
public ChatOptions copy() {
116-
return ChatOptions.builder()
117-
.model(this.model)
118-
.frequencyPenalty(this.frequencyPenalty)
119-
.maxTokens(this.maxTokens)
120-
.presencePenalty(this.presencePenalty)
121-
.stopSequences(this.stopSequences != null ? List.copyOf(this.stopSequences) : null)
122-
.temperature(this.temperature)
123-
.topK(this.topK)
124-
.topP(this.topP)
125-
.build();
117+
@SuppressWarnings("unchecked")
118+
public <T extends ChatOptions> T copy() {
119+
DefaultChatOptions copy = new DefaultChatOptions();
120+
copy.setModel(this.getModel());
121+
copy.setFrequencyPenalty(this.getFrequencyPenalty());
122+
copy.setMaxTokens(this.getMaxTokens());
123+
copy.setPresencePenalty(this.getPresencePenalty());
124+
copy.setStopSequences(this.getStopSequences() != null ? new ArrayList<>(this.getStopSequences()) : null);
125+
copy.setTemperature(this.getTemperature());
126+
copy.setTopK(this.getTopK());
127+
copy.setTopP(this.getTopP());
128+
return (T) copy;
126129
}
127130

128131
}

0 commit comments

Comments
 (0)