Skip to content

Commit 58b9775

Browse files
nicolaskrierWillam2004
authored andcommitted
Refactor request creation of Mistral AI chat model
- Added createChatCompletionMessages in MistralAiChatModel to map Spring AI Message objects to Mistral AI’s ChatCompletionMessage stream, based on MessageType. - Rename Mistral AI chat completion request unit tests - Stop using deprecated assistant message constructor - Added comprehensive tests for the new conversion method covering: - Standard Spring AI messages. - Custom message types. - Improved existing tests by: - Removing redundant @SpringBootTest annotation. - Aligning test class name with project conventions. Signed-off-by: Nicolas Krier <[email protected]> Signed-off-by: 家娃 <[email protected]>
1 parent 6259798 commit 58b9775

File tree

3 files changed

+401
-172
lines changed

3 files changed

+401
-172
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 81 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.stream.Stream;
2425

2526
import io.micrometer.observation.Observation;
2627
import io.micrometer.observation.ObservationRegistry;
@@ -32,7 +33,7 @@
3233
import reactor.core.scheduler.Schedulers;
3334

3435
import org.springframework.ai.chat.messages.AssistantMessage;
35-
import org.springframework.ai.chat.messages.SystemMessage;
36+
import org.springframework.ai.chat.messages.Message;
3637
import org.springframework.ai.chat.messages.ToolResponseMessage;
3738
import org.springframework.ai.chat.messages.UserMessage;
3839
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -84,6 +85,7 @@
8485
* @author luocongqiu
8586
* @author Ilayaperumal Gopinathan
8687
* @author Alexandros Pappas
88+
* @author Nicolas Krier
8789
* @since 1.0.0
8890
*/
8991
public class MistralAiChatModel implements ChatModel {
@@ -429,52 +431,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
429431
* Accessible for testing.
430432
*/
431433
MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
432-
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
433-
if (message instanceof UserMessage userMessage) {
434-
Object content = message.getText();
435-
436-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
437-
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
438-
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
439-
440-
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
441-
442-
content = contentList;
443-
}
444-
445-
return List
446-
.of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER));
447-
}
448-
else if (message instanceof SystemMessage systemMessage) {
449-
return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
450-
MistralAiApi.ChatCompletionMessage.Role.SYSTEM));
451-
}
452-
else if (message instanceof AssistantMessage assistantMessage) {
453-
List<ToolCall> toolCalls = null;
454-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
455-
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
456-
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
457-
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
458-
}).toList();
459-
}
460-
461-
return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(),
462-
MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
463-
}
464-
else if (message instanceof ToolResponseMessage toolResponseMessage) {
465-
toolResponseMessage.getResponses()
466-
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
467-
468-
return toolResponseMessage.getResponses()
469-
.stream()
470-
.map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(),
471-
MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()))
472-
.toList();
473-
}
474-
else {
475-
throw new IllegalStateException("Unexpected message type: " + message);
476-
}
477-
}).flatMap(List::stream).toList();
434+
// @formatter:off
435+
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
436+
.stream()
437+
.flatMap(this::createChatCompletionMessages)
438+
.toList();
439+
// @formatter:on
478440

479441
var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
480442

@@ -492,6 +454,78 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
492454
return request;
493455
}
494456

457+
private Stream<ChatCompletionMessage> createChatCompletionMessages(Message message) {
458+
switch (message.getMessageType()) {
459+
case USER:
460+
return Stream.of(createUserChatCompletionMessage(message));
461+
case SYSTEM:
462+
return Stream.of(createSystemChatCompletionMessage(message));
463+
case ASSISTANT:
464+
return Stream.of(createAssistantChatCompletionMessage(message));
465+
case TOOL:
466+
return createToolChatCompletionMessages(message);
467+
default:
468+
throw new IllegalStateException("Unknown message type: " + message.getMessageType());
469+
}
470+
}
471+
472+
private Stream<ChatCompletionMessage> createToolChatCompletionMessages(Message message) {
473+
if (message instanceof ToolResponseMessage toolResponseMessage) {
474+
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
475+
476+
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
477+
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage.ToolResponse must have an id.");
478+
var chatCompletionMessage = new ChatCompletionMessage(toolResponse.responseData(),
479+
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id());
480+
chatCompletionMessages.add(chatCompletionMessage);
481+
}
482+
483+
return chatCompletionMessages.stream();
484+
}
485+
else {
486+
throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName());
487+
}
488+
}
489+
490+
private ChatCompletionMessage createAssistantChatCompletionMessage(Message message) {
491+
if (message instanceof AssistantMessage assistantMessage) {
492+
List<ToolCall> toolCalls = null;
493+
494+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
495+
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
496+
}
497+
498+
return new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null,
499+
toolCalls, null);
500+
}
501+
else {
502+
throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName());
503+
}
504+
}
505+
506+
private ChatCompletionMessage createSystemChatCompletionMessage(Message message) {
507+
return new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM);
508+
}
509+
510+
private ChatCompletionMessage createUserChatCompletionMessage(Message message) {
511+
Object content = message.getText();
512+
513+
if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
514+
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
515+
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
516+
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
517+
content = contentList;
518+
}
519+
520+
return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER);
521+
}
522+
523+
private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {
524+
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
525+
526+
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
527+
}
528+
495529
private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
496530
return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
497531
this.fromMediaData(media.getMimeType(), media.getData())));

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java

Lines changed: 0 additions & 125 deletions
This file was deleted.

0 commit comments

Comments
 (0)