Skip to content

Commit ab81d57

Browse files
alxkmsobychacko
authored andcommitted
test: Improve unit test coverage for various components
- Add comprehensive validation tests for OpenAI runtime hints registration - Add tests for MistralAI retry logic and AOT native image support - Add comprehensive test coverage for Azure OpenAI chat and embedding options - Remove shouldHandleVeryLargeInputList to reduce load during build - Add comprehensive edge case tests for ChatModel, ChatResponse, and ListOutputConverter - Add whenMultipleGenerationsWithToolCallsThenReturnTrue test - Add comprehensive edge case tests for OllamaOptions and OllamaRuntimeHints - Add additional test coverage for edge cases and boundary conditions in Ollama components - Add verifyHintsRegistration test - Add edge case tests for SimpleVectorStore document handling and filtering - Add validation and edge case tests for ObservabilityHelper and AiOperationMetadata - Add builtWithValidValuesThenFieldsAreAccessible test Signed-off-by: Oleksandr Klymenko <[email protected]>
1 parent 568ca96 commit ab81d57

File tree

15 files changed

+619
-1
lines changed

15 files changed

+619
-1
lines changed

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import java.util.ArrayList;
1920
import java.util.Arrays;
2021
import java.util.Collections;
2122
import java.util.List;
@@ -158,4 +159,45 @@ public void shouldHandleConcurrentRequests() {
158159
assertThat(options2.getUser()).isEqualTo("USER2");
159160
}
160161

162+
@Test
163+
public void shouldHandleEmptyStringInputs() {
164+
List<String> inputsWithEmpty = Arrays.asList("", "Valid text", "", "Another valid text");
165+
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(inputsWithEmpty, null));
166+
167+
assertThat(requestOptions.getInput()).hasSize(4);
168+
assertThat(requestOptions.getInput()).containsExactly("", "Valid text", "", "Another valid text");
169+
}
170+
171+
@Test
172+
public void shouldHandleDifferentClientConfigurations() {
173+
var clientWithDifferentDefaults = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
174+
AzureOpenAiEmbeddingOptions.builder().deploymentName("DIFFERENT_DEFAULT").build());
175+
176+
var requestOptions = clientWithDifferentDefaults
177+
.toEmbeddingOptions(new EmbeddingRequest(List.of("Test content"), null));
178+
179+
assertThat(requestOptions.getModel()).isEqualTo("DIFFERENT_DEFAULT");
180+
assertThat(requestOptions.getUser()).isNull(); // No default user set
181+
}
182+
183+
@Test
184+
public void shouldHandleWhitespaceOnlyInputs() {
185+
List<String> whitespaceInputs = Arrays.asList(" ", "\t\t", "\n\n", " valid text ");
186+
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(whitespaceInputs, null));
187+
188+
assertThat(requestOptions.getInput()).hasSize(4);
189+
assertThat(requestOptions.getInput()).containsExactlyElementsOf(whitespaceInputs);
190+
}
191+
192+
@Test
193+
public void shouldValidateInputListIsNotModified() {
194+
List<String> originalInputs = Arrays.asList("Input 1", "Input 2", "Input 3");
195+
List<String> inputsCopy = new ArrayList<>(originalInputs);
196+
197+
client.toEmbeddingOptions(new EmbeddingRequest(inputsCopy, null));
198+
199+
// Verify original list wasn't modified
200+
assertThat(inputsCopy).isEqualTo(originalInputs);
201+
}
202+
161203
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,49 @@ void testDefaultValues() {
195195
assertThat(options.getModel()).isNull();
196196
}
197197

198+
@Test
199+
void testModelAndDeploymentNameRelationship() {
200+
AzureOpenAiChatOptions options = new AzureOpenAiChatOptions();
201+
202+
// Test setting deployment name first
203+
options.setDeploymentName("deployment-1");
204+
assertThat(options.getDeploymentName()).isEqualTo("deployment-1");
205+
assertThat(options.getModel()).isEqualTo("deployment-1");
206+
207+
// Test setting model overwrites deployment name
208+
options.setModel("model-1");
209+
assertThat(options.getDeploymentName()).isEqualTo("model-1");
210+
assertThat(options.getModel()).isEqualTo("model-1");
211+
}
212+
213+
@Test
214+
void testResponseFormatVariations() {
215+
// Test with JSON response format
216+
AzureOpenAiResponseFormat jsonFormat = AzureOpenAiResponseFormat.builder()
217+
.type(AzureOpenAiResponseFormat.Type.JSON_OBJECT)
218+
.build();
219+
220+
AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder().responseFormat(jsonFormat).build();
221+
222+
assertThat(options.getResponseFormat()).isEqualTo(jsonFormat);
223+
assertThat(options.getResponseFormat().getType()).isEqualTo(AzureOpenAiResponseFormat.Type.JSON_OBJECT);
224+
}
225+
226+
@Test
227+
void testEnhancementsConfiguration() {
228+
AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();
229+
AzureChatOCREnhancementConfiguration ocrConfig = new AzureChatOCREnhancementConfiguration(false);
230+
AzureChatGroundingEnhancementConfiguration groundingConfig = new AzureChatGroundingEnhancementConfiguration(
231+
false);
232+
233+
enhancements.setOcr(ocrConfig);
234+
enhancements.setGrounding(groundingConfig);
235+
236+
AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder().enhancements(enhancements).build();
237+
238+
assertThat(options.getEnhancements()).isEqualTo(enhancements);
239+
assertThat(options.getEnhancements().getOcr()).isEqualTo(ocrConfig);
240+
assertThat(options.getEnhancements().getGrounding()).isEqualTo(groundingConfig);
241+
}
242+
198243
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ public void mistralAiEmbeddingNonTransientError() {
178178
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
179179
}
180180

181+
@Test
182+
public void mistralAiChatMixedTransientAndNonTransientErrors() {
183+
given(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
184+
.willThrow(new TransientAiException("Transient Error"))
185+
.willThrow(new RuntimeException("Non Transient Error"));
186+
187+
// Should fail immediately on non-transient error, no further retries
188+
assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text")));
189+
190+
// Should have 1 retry attempt before hitting non-transient error
191+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
192+
}
193+
181194
private static class TestRetryListener implements RetryListener {
182195

183196
int onErrorRetryCount = 0;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,60 @@ void verifyPackageScanningWorks() {
118118
assertThat(jsonAnnotatedClasses.size()).isGreaterThan(0);
119119
}
120120

121+
@Test
122+
void verifyAllCriticalApiClassesAreRegistered() {
123+
RuntimeHints runtimeHints = new RuntimeHints();
124+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
125+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
126+
127+
Set<TypeReference> registeredTypes = new HashSet<>();
128+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
129+
130+
// Ensure critical API classes are registered for GraalVM native image reflection
131+
String[] criticalClasses = { "MistralAiApi$ChatCompletionRequest", "MistralAiApi$ChatCompletionMessage",
132+
"MistralAiApi$EmbeddingRequest", "MistralAiApi$EmbeddingList", "MistralAiApi$Usage" };
133+
134+
for (String className : criticalClasses) {
135+
assertThat(registeredTypes.stream()
136+
.anyMatch(tr -> tr.getName().contains(className.replace("$", "."))
137+
|| tr.getName().contains(className.replace("$", "$"))))
138+
.as("Critical class %s should be registered", className)
139+
.isTrue();
140+
}
141+
}
142+
143+
@Test
144+
void verifyEnumTypesAreRegistered() {
145+
RuntimeHints runtimeHints = new RuntimeHints();
146+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
147+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
148+
149+
Set<TypeReference> registeredTypes = new HashSet<>();
150+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
151+
152+
// Enums are critical for JSON deserialization in native images
153+
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatModel.class)))
154+
.as("ChatModel enum should be registered")
155+
.isTrue();
156+
157+
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.EmbeddingModel.class)))
158+
.as("EmbeddingModel enum should be registered")
159+
.isTrue();
160+
}
161+
162+
@Test
163+
void verifyReflectionHintsIncludeConstructors() {
164+
RuntimeHints runtimeHints = new RuntimeHints();
165+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
166+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
167+
168+
// Verify that reflection hints include constructor access
169+
boolean hasConstructorHints = runtimeHints.reflection()
170+
.typeHints()
171+
.anyMatch(typeHint -> typeHint.constructors().findAny().isPresent() || typeHint.getMemberCategories()
172+
.contains(org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS));
173+
174+
assertThat(hasConstructorHints).as("Should register constructor hints for JSON deserialization").isTrue();
175+
}
176+
121177
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,46 @@ void verifyNestedClassHintsAreRegistered() {
150150
assertThat(nestedClassCount).isGreaterThan(0);
151151
}
152152

153+
@Test
154+
void verifyEmbeddingRelatedClassesAreRegistered() {
155+
RuntimeHints runtimeHints = new RuntimeHints();
156+
OllamaRuntimeHints ollamaRuntimeHints = new OllamaRuntimeHints();
157+
ollamaRuntimeHints.registerHints(runtimeHints, null);
158+
159+
Set<TypeReference> registeredTypes = new HashSet<>();
160+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
161+
162+
// Verify embedding-related classes are registered for reflection
163+
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.EmbeddingsRequest.class))).isTrue();
164+
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.EmbeddingsResponse.class))).isTrue();
165+
166+
// Count classes related to embedding functionality
167+
long embeddingClassCount = registeredTypes.stream()
168+
.filter(typeRef -> typeRef.getName().toLowerCase().contains("embedding"))
169+
.count();
170+
assertThat(embeddingClassCount).isGreaterThan(0);
171+
}
172+
173+
@Test
174+
void verifyHintsRegistrationWithCustomClassLoader() {
175+
RuntimeHints runtimeHints = new RuntimeHints();
176+
OllamaRuntimeHints ollamaRuntimeHints = new OllamaRuntimeHints();
177+
178+
// Create a custom class loader
179+
ClassLoader customClassLoader = Thread.currentThread().getContextClassLoader();
180+
181+
// Should work with custom class loader
182+
org.assertj.core.api.Assertions
183+
.assertThatCode(() -> ollamaRuntimeHints.registerHints(runtimeHints, customClassLoader))
184+
.doesNotThrowAnyException();
185+
186+
// Verify hints are still registered properly
187+
Set<TypeReference> registeredTypes = new HashSet<>();
188+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
189+
190+
assertThat(registeredTypes.size()).isGreaterThan(0);
191+
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.class))).isTrue();
192+
assertThat(registeredTypes.contains(TypeReference.of(OllamaOptions.class))).isTrue();
193+
}
194+
153195
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,42 @@ public void testDeprecatedMethods() {
204204
assertThat(options.getToolNames()).containsExactly("function1");
205205
}
206206

207+
@Test
208+
public void testEmptyOptions() {
209+
var options = OllamaOptions.builder().build();
210+
211+
var optionsMap = options.toMap();
212+
assertThat(optionsMap).isEmpty();
213+
214+
// Verify all getters return null/empty
215+
assertThat(options.getModel()).isNull();
216+
assertThat(options.getTemperature()).isNull();
217+
assertThat(options.getTopK()).isNull();
218+
assertThat(options.getToolNames()).isEmpty();
219+
assertThat(options.getToolContext()).isEmpty();
220+
}
221+
222+
@Test
223+
public void testNullValuesNotIncludedInMap() {
224+
var options = OllamaOptions.builder().model("llama2").temperature(null).topK(null).stop(null).build();
225+
226+
var optionsMap = options.toMap();
227+
assertThat(optionsMap).containsEntry("model", "llama2");
228+
assertThat(optionsMap).doesNotContainKey("temperature");
229+
assertThat(optionsMap).doesNotContainKey("top_k");
230+
assertThat(optionsMap).doesNotContainKey("stop");
231+
}
232+
233+
@Test
234+
public void testZeroValuesIncludedInMap() {
235+
var options = OllamaOptions.builder().temperature(0.0).topK(0).mainGPU(0).numGPU(0).seed(0).build();
236+
237+
var optionsMap = options.toMap();
238+
assertThat(optionsMap).containsEntry("temperature", 0.0);
239+
assertThat(optionsMap).containsEntry("top_k", 0);
240+
assertThat(optionsMap).containsEntry("main_gpu", 0);
241+
assertThat(optionsMap).containsEntry("num_gpu", 0);
242+
assertThat(optionsMap).containsEntry("seed", 0);
243+
}
244+
207245
}

0 commit comments

Comments
 (0)