2121import java .util .List ;
2222import java .util .Map ;
2323import java .util .concurrent .ConcurrentHashMap ;
24+ import java .util .stream .Stream ;
2425
2526import io .micrometer .observation .Observation ;
2627import io .micrometer .observation .ObservationRegistry ;
3233import reactor .core .scheduler .Schedulers ;
3334
3435import org .springframework .ai .chat .messages .AssistantMessage ;
35- import org .springframework .ai .chat .messages .SystemMessage ;
36+ import org .springframework .ai .chat .messages .Message ;
3637import org .springframework .ai .chat .messages .ToolResponseMessage ;
3738import org .springframework .ai .chat .messages .UserMessage ;
3839import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
8485 * @author luocongqiu
8586 * @author Ilayaperumal Gopinathan
8687 * @author Alexandros Pappas
88+ * @author Nicolas Krier
8789 * @since 1.0.0
8890 */
8991public class MistralAiChatModel implements ChatModel {
@@ -429,52 +431,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
429431 * Accessible for testing.
430432 */
431433 MistralAiApi .ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
432- List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ().stream ().map (message -> {
433- if (message instanceof UserMessage userMessage ) {
434- Object content = message .getText ();
435-
436- if (!CollectionUtils .isEmpty (userMessage .getMedia ())) {
437- List <ChatCompletionMessage .MediaContent > contentList = new ArrayList <>(
438- List .of (new ChatCompletionMessage .MediaContent (message .getText ())));
439-
440- contentList .addAll (userMessage .getMedia ().stream ().map (this ::mapToMediaContent ).toList ());
441-
442- content = contentList ;
443- }
444-
445- return List
446- .of (new MistralAiApi .ChatCompletionMessage (content , MistralAiApi .ChatCompletionMessage .Role .USER ));
447- }
448- else if (message instanceof SystemMessage systemMessage ) {
449- return List .of (new MistralAiApi .ChatCompletionMessage (systemMessage .getText (),
450- MistralAiApi .ChatCompletionMessage .Role .SYSTEM ));
451- }
452- else if (message instanceof AssistantMessage assistantMessage ) {
453- List <ToolCall > toolCalls = null ;
454- if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
455- toolCalls = assistantMessage .getToolCalls ().stream ().map (toolCall -> {
456- var function = new ChatCompletionFunction (toolCall .name (), toolCall .arguments ());
457- return new ToolCall (toolCall .id (), toolCall .type (), function , null );
458- }).toList ();
459- }
460-
461- return List .of (new MistralAiApi .ChatCompletionMessage (assistantMessage .getText (),
462- MistralAiApi .ChatCompletionMessage .Role .ASSISTANT , null , toolCalls , null ));
463- }
464- else if (message instanceof ToolResponseMessage toolResponseMessage ) {
465- toolResponseMessage .getResponses ()
466- .forEach (response -> Assert .isTrue (response .id () != null , "ToolResponseMessage must have an id" ));
467-
468- return toolResponseMessage .getResponses ()
469- .stream ()
470- .map (toolResponse -> new MistralAiApi .ChatCompletionMessage (toolResponse .responseData (),
471- MistralAiApi .ChatCompletionMessage .Role .TOOL , toolResponse .name (), null , toolResponse .id ()))
472- .toList ();
473- }
474- else {
475- throw new IllegalStateException ("Unexpected message type: " + message );
476- }
477- }).flatMap (List ::stream ).toList ();
434+ // @formatter:off
435+ List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ()
436+ .stream ()
437+ .flatMap (this ::createChatCompletionMessages )
438+ .toList ();
439+ // @formatter:on
478440
479441 var request = new MistralAiApi .ChatCompletionRequest (chatCompletionMessages , stream );
480442
@@ -492,6 +454,78 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
492454 return request ;
493455 }
494456
457+ private Stream <ChatCompletionMessage > createChatCompletionMessages (Message message ) {
458+ switch (message .getMessageType ()) {
459+ case USER :
460+ return Stream .of (createUserChatCompletionMessage (message ));
461+ case SYSTEM :
462+ return Stream .of (createSystemChatCompletionMessage (message ));
463+ case ASSISTANT :
464+ return Stream .of (createAssistantChatCompletionMessage (message ));
465+ case TOOL :
466+ return createToolChatCompletionMessages (message );
467+ default :
468+ throw new IllegalStateException ("Unknown message type: " + message .getMessageType ());
469+ }
470+ }
471+
472+ private Stream <ChatCompletionMessage > createToolChatCompletionMessages (Message message ) {
473+ if (message instanceof ToolResponseMessage toolResponseMessage ) {
474+ var chatCompletionMessages = new ArrayList <ChatCompletionMessage >();
475+
476+ for (ToolResponseMessage .ToolResponse toolResponse : toolResponseMessage .getResponses ()) {
477+ Assert .isTrue (toolResponse .id () != null , "ToolResponseMessage.ToolResponse must have an id." );
478+ var chatCompletionMessage = new ChatCompletionMessage (toolResponse .responseData (),
479+ ChatCompletionMessage .Role .TOOL , toolResponse .name (), null , toolResponse .id ());
480+ chatCompletionMessages .add (chatCompletionMessage );
481+ }
482+
483+ return chatCompletionMessages .stream ();
484+ }
485+ else {
486+ throw new IllegalArgumentException ("Unsupported tool message class: " + message .getClass ().getName ());
487+ }
488+ }
489+
490+ private ChatCompletionMessage createAssistantChatCompletionMessage (Message message ) {
491+ if (message instanceof AssistantMessage assistantMessage ) {
492+ List <ToolCall > toolCalls = null ;
493+
494+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
495+ toolCalls = assistantMessage .getToolCalls ().stream ().map (this ::mapToolCall ).toList ();
496+ }
497+
498+ return new ChatCompletionMessage (assistantMessage .getText (), ChatCompletionMessage .Role .ASSISTANT , null ,
499+ toolCalls , null );
500+ }
501+ else {
502+ throw new IllegalArgumentException ("Unsupported assistant message class: " + message .getClass ().getName ());
503+ }
504+ }
505+
506+ private ChatCompletionMessage createSystemChatCompletionMessage (Message message ) {
507+ return new ChatCompletionMessage (message .getText (), ChatCompletionMessage .Role .SYSTEM );
508+ }
509+
510+ private ChatCompletionMessage createUserChatCompletionMessage (Message message ) {
511+ Object content = message .getText ();
512+
513+ if (message instanceof UserMessage userMessage && !CollectionUtils .isEmpty (userMessage .getMedia ())) {
514+ List <ChatCompletionMessage .MediaContent > contentList = new ArrayList <>(
515+ List .of (new ChatCompletionMessage .MediaContent (message .getText ())));
516+ contentList .addAll (userMessage .getMedia ().stream ().map (this ::mapToMediaContent ).toList ());
517+ content = contentList ;
518+ }
519+
520+ return new ChatCompletionMessage (content , ChatCompletionMessage .Role .USER );
521+ }
522+
523+ private ToolCall mapToolCall (AssistantMessage .ToolCall toolCall ) {
524+ var function = new ChatCompletionFunction (toolCall .name (), toolCall .arguments ());
525+
526+ return new ToolCall (toolCall .id (), toolCall .type (), function , null );
527+ }
528+
495529 private ChatCompletionMessage .MediaContent mapToMediaContent (Media media ) {
496530 return new ChatCompletionMessage .MediaContent (new ChatCompletionMessage .MediaContent .ImageUrl (
497531 this .fromMediaData (media .getMimeType (), media .getData ())));
0 commit comments