Skip to content

Commit 90d3823

Browse files
habumaleijendary
authored andcommitted
Fix StabilityAI options merge precedence
This fix ensures proper inheritance of defaults while maintaining the correct precedence order for both generic and StabilityAI-specific options. Added tests to verify the merge behavior for runtime options, default options, and generic ImageOptions cases. Signed-off-by: leijendary <[email protected]>
1 parent 0aad21e commit 90d3823

File tree

3 files changed

+183
-8
lines changed

3 files changed

+183
-8
lines changed

models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,14 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener
115115

116116
/**
117117
* Merge runtime and default {@link ImageOptions} to compute the final options to use
118-
* in the request.
118+
* in the request. Protected access for testing purposes, though maybe useful for
119+
* future subclassing as options change.
119120
*/
120-
private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) {
121+
StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) {
121122
if (runtimeOptions == null) {
122123
return defaultOptions;
123124
}
124-
125-
return StabilityAiImageOptions.builder()
125+
StabilityAiImageOptions.Builder builder = StabilityAiImageOptions.builder()
126126
// Handle portable image options
127127
.withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
128128
.withN(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN()))
@@ -131,14 +131,29 @@ private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, Stabil
131131
.withWidth(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth()))
132132
.withHeight(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight()))
133133
.withStylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()))
134-
// Handle Stability AI specific image options
134+
// Always set the stability-specific defaults
135135
.withCfgScale(defaultOptions.getCfgScale())
136136
.withClipGuidancePreset(defaultOptions.getClipGuidancePreset())
137137
.withSampler(defaultOptions.getSampler())
138138
.withSeed(defaultOptions.getSeed())
139139
.withSteps(defaultOptions.getSteps())
140-
.withStylePreset(defaultOptions.getStylePreset())
141-
.build();
140+
.withStylePreset(defaultOptions.getStylePreset());
141+
if (runtimeOptions instanceof StabilityAiImageOptions) {
142+
StabilityAiImageOptions stabilityOptions = (StabilityAiImageOptions) runtimeOptions;
143+
// Handle Stability AI specific image options
144+
builder
145+
.withCfgScale(
146+
ModelOptionsUtils.mergeOption(stabilityOptions.getCfgScale(), defaultOptions.getCfgScale()))
147+
.withClipGuidancePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getClipGuidancePreset(),
148+
defaultOptions.getClipGuidancePreset()))
149+
.withSampler(ModelOptionsUtils.mergeOption(stabilityOptions.getSampler(), defaultOptions.getSampler()))
150+
.withSeed(ModelOptionsUtils.mergeOption(stabilityOptions.getSeed(), defaultOptions.getSeed()))
151+
.withSteps(ModelOptionsUtils.mergeOption(stabilityOptions.getSteps(), defaultOptions.getSteps()))
152+
.withStylePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getStylePreset(),
153+
defaultOptions.getStylePreset()));
154+
}
155+
156+
return builder.build();
142157
}
143158

144159
}

models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public record GenerateImageRequest(@JsonProperty("text_prompts") List<TextPrompt
100100
@JsonProperty("cfg_scale") Float cfgScale, @JsonProperty("clip_guidance_preset") String clipGuidancePreset,
101101
@JsonProperty("sampler") String sampler, @JsonProperty("samples") Integer samples,
102102
@JsonProperty("seed") Long seed, @JsonProperty("steps") Integer steps,
103-
@JsonProperty("style_present") String stylePreset) {
103+
@JsonProperty("style_preset") String stylePreset) {
104104

105105
public static Builder builder() {
106106
return new Builder();
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.stabilityai;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.springframework.ai.image.ImageOptions;
20+
import org.springframework.ai.stabilityai.api.StabilityAiApi;
21+
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
22+
23+
import static org.assertj.core.api.Assertions.assertThat;
24+
import static org.mockito.Mockito.mock;
25+
26+
public class StabilityAiImageOptionsTests {
27+
28+
@Test
29+
void shouldPreferRuntimeOptionsOverDefaultOptions() {
30+
31+
StabilityAiApi stabilityAiApi = mock(StabilityAiApi.class);
32+
// Default options
33+
StabilityAiImageOptions defaultOptions = StabilityAiImageOptions.builder()
34+
.withN(1)
35+
.withModel("default-model")
36+
.withWidth(512)
37+
.withHeight(512)
38+
.withResponseFormat("image/png")
39+
.withCfgScale(7.0f)
40+
.withClipGuidancePreset("FAST_BLUE")
41+
.withSampler("DDIM")
42+
.withSeed(1234L)
43+
.withSteps(30)
44+
.withStylePreset("3d-model")
45+
.build();
46+
47+
// Runtime options with different values
48+
StabilityAiImageOptions runtimeOptions = StabilityAiImageOptions.builder()
49+
.withN(2)
50+
.withModel("runtime-model")
51+
.withWidth(1024)
52+
.withHeight(768)
53+
.withResponseFormat("application/json")
54+
.withCfgScale(14.0f)
55+
.withClipGuidancePreset("FAST_GREEN")
56+
.withSampler("DDPM")
57+
.withSeed(5678L)
58+
.withSteps(50)
59+
.withStylePreset("anime")
60+
.build();
61+
62+
StabilityAiImageModel imageModel = new StabilityAiImageModel(stabilityAiApi, defaultOptions);
63+
64+
StabilityAiImageOptions mergedOptions = imageModel.mergeOptions(runtimeOptions, defaultOptions);
65+
66+
assertThat(mergedOptions).satisfies(options -> {
67+
// Verify that all options match the runtime values, not the defaults
68+
assertThat(options.getN()).isEqualTo(2);
69+
assertThat(options.getModel()).isEqualTo("runtime-model");
70+
assertThat(options.getWidth()).isEqualTo(1024);
71+
assertThat(options.getHeight()).isEqualTo(768);
72+
assertThat(options.getResponseFormat()).isEqualTo("application/json");
73+
assertThat(options.getCfgScale()).isEqualTo(14.0f);
74+
assertThat(options.getClipGuidancePreset()).isEqualTo("FAST_GREEN");
75+
assertThat(options.getSampler()).isEqualTo("DDPM");
76+
assertThat(options.getSeed()).isEqualTo(5678L);
77+
assertThat(options.getSteps()).isEqualTo(50);
78+
assertThat(options.getStylePreset()).isEqualTo("anime");
79+
});
80+
}
81+
82+
@Test
83+
void shouldUseDefaultOptionsWhenRuntimeOptionsAreNull() {
84+
85+
StabilityAiApi stabilityAiApi = mock(StabilityAiApi.class);
86+
StabilityAiImageOptions defaultOptions = StabilityAiImageOptions.builder()
87+
.withN(1)
88+
.withModel("default-model")
89+
.withCfgScale(7.0f)
90+
.build();
91+
92+
StabilityAiImageModel imageModel = new StabilityAiImageModel(stabilityAiApi, defaultOptions);
93+
94+
StabilityAiImageOptions mergedOptions = imageModel.mergeOptions(null, defaultOptions);
95+
96+
assertThat(mergedOptions).satisfies(options -> {
97+
assertThat(options.getN()).isEqualTo(1);
98+
assertThat(options.getModel()).isEqualTo("default-model");
99+
assertThat(options.getCfgScale()).isEqualTo(7.0f);
100+
});
101+
}
102+
103+
@Test
104+
void shouldHandleGenericImageOptionsCorrectly() {
105+
106+
StabilityAiApi stabilityAiApi = mock(StabilityAiApi.class);
107+
StabilityAiImageOptions defaultOptions = StabilityAiImageOptions.builder()
108+
.withN(1)
109+
.withModel("default-model")
110+
.withWidth(512)
111+
.withCfgScale(7.0f)
112+
.build();
113+
114+
// Create a non-StabilityAi ImageOptions implementation
115+
ImageOptions genericOptions = new ImageOptions() {
116+
@Override
117+
public Integer getN() {
118+
return 2;
119+
}
120+
121+
@Override
122+
public String getModel() {
123+
return "generic-model";
124+
}
125+
126+
@Override
127+
public Integer getWidth() {
128+
return 1024;
129+
}
130+
131+
@Override
132+
public Integer getHeight() {
133+
return null;
134+
}
135+
136+
@Override
137+
public String getResponseFormat() {
138+
return null;
139+
}
140+
141+
@Override
142+
public String getStyle() {
143+
return null;
144+
}
145+
};
146+
147+
StabilityAiImageModel imageModel = new StabilityAiImageModel(stabilityAiApi, defaultOptions);
148+
149+
StabilityAiImageOptions mergedOptions = imageModel.mergeOptions(genericOptions, defaultOptions);
150+
151+
// Generic options should override defaults
152+
assertThat(mergedOptions.getN()).isEqualTo(2);
153+
assertThat(mergedOptions.getModel()).isEqualTo("generic-model");
154+
assertThat(mergedOptions.getWidth()).isEqualTo(1024);
155+
156+
// Stability-specific options should retain default values
157+
assertThat(mergedOptions.getCfgScale()).isEqualTo(7.0f);
158+
}
159+
160+
}

0 commit comments

Comments
 (0)