1818
1919import java .net .MalformedURLException ;
2020import java .net .URI ;
21+ import java .nio .charset .StandardCharsets ;
2122import java .util .List ;
2223import java .util .Map ;
2324
2930import org .mockito .Mock ;
3031import org .mockito .Mockito ;
3132import org .mockito .junit .jupiter .MockitoExtension ;
33+ import org .springframework .core .io .ByteArrayResource ;
34+ import org .springframework .util .MimeType ;
3235import reactor .core .publisher .Flux ;
3336
3437import org .springframework .ai .chat .messages .SystemMessage ;
@@ -126,11 +129,8 @@ public void userMessageWithMediaType() throws MalformedURLException {
126129 given (this .openAiApi .chatCompletionEntity (this .pomptCaptor .capture (), this .headersCaptor .capture ()))
127130 .willReturn (Mockito .mock (ResponseEntity .class ));
128131
129- URI mediaUri = URI .create ("http://test" );
130- this .chatModel .call (new Prompt (List .of (UserMessage .builder ()
131- .text ("test message" )
132- .media (List .of (Media .builder ().mimeType (MimeTypeUtils .IMAGE_JPEG ).data (mediaUri ).build ()))
133- .build ())));
132+ this .chatModel
133+ .call (new Prompt (List .of (UserMessage .builder ().text ("test message" ).media (this .buildMediaList ()).build ())));
134134
135135 validateComplexContent (this .pomptCaptor .getValue ());
136136 }
@@ -141,11 +141,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
141141 given (this .openAiApi .chatCompletionStream (this .pomptCaptor .capture (), this .headersCaptor .capture ()))
142142 .willReturn (this .fluxResponse );
143143
144- URI mediaUrl = URI .create ("http://test" );
145- this .chatModel .stream (new Prompt (List .of (UserMessage .builder ()
146- .text ("test message" )
147- .media (List .of (Media .builder ().mimeType (MimeTypeUtils .IMAGE_JPEG ).data (mediaUrl ).build ()))
148- .build ()))).subscribe ();
144+ this .chatModel
145+ .stream (new Prompt (
146+ List .of (UserMessage .builder ().text ("test message" ).media (this .buildMediaList ()).build ())))
147+ .subscribe ();
149148
150149 validateComplexContent (this .pomptCaptor .getValue ());
151150 }
@@ -161,16 +160,40 @@ private void validateComplexContent(ChatCompletionRequest chatCompletionRequest)
161160 @ SuppressWarnings ({ "unused" , "unchecked" })
162161 List <Map <String , Object >> mediaContents = (List <Map <String , Object >>) userMessage .rawContent ();
163162
164- assertThat (mediaContents ).hasSize (2 );
163+ assertThat (mediaContents ).hasSize (3 );
165164
165+ // Assert text content
166166 Map <String , Object > textContent = mediaContents .get (0 );
167167 assertThat (textContent .get ("type" )).isEqualTo ("text" );
168168 assertThat (textContent .get ("text" )).isEqualTo ("test message" );
169169
170+ // Assert image content
170171 Map <String , Object > imageContent = mediaContents .get (1 );
171172
172173 assertThat (imageContent .get ("type" )).isEqualTo ("image_url" );
173174 assertThat (imageContent ).containsKey ("image_url" );
175+
176+ // Assert file content
177+ Map <String , Object > fileContent = mediaContents .get (2 );
178+ assertThat (fileContent .get ("type" )).isEqualTo ("file" );
179+ assertThat (fileContent ).containsKey ("file" );
180+ assertThat (fileContent .get ("file" )).isInstanceOf (Map .class );
181+
182+ Map <String , Object > fileMap = (Map <String , Object >) fileContent .get ("file" );
183+ assertThat (fileMap .get ("file_data" )).isEqualTo ("data:application/pdf;base64,JVBERi0xLjc=" );
184+ }
185+
186+ private List <Media > buildMediaList () {
187+ URI imageUri = URI .create ("http://test" );
188+ Media imageMedia = Media .builder ().mimeType (MimeTypeUtils .IMAGE_JPEG ).data (imageUri ).build ();
189+
190+ byte [] pdfData = "%PDF-1.7" .getBytes (StandardCharsets .UTF_8 );
191+ Media pdfMedia = Media .builder ()
192+ .mimeType (MimeType .valueOf ("application/pdf" ))
193+ .data (new ByteArrayResource (pdfData ))
194+ .build ();
195+
196+ return List .of (imageMedia , pdfMedia );
174197 }
175198
176199}
0 commit comments