@@ -194,14 +194,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
194194 }
195195
196196 fn should_add_generation_prompt ( & self ) -> bool {
197- if let Some ( last) = self . inner . messages . last ( ) {
198- matches ! (
199- last,
200- dynamo_async_openai:: types:: ChatCompletionRequestMessage :: User ( _)
201- )
202- } else {
203- true
204- }
197+ // Only add generation prompt if the last message was not assistant (default to true when no last message)
198+ self . inner
199+ . messages
200+ . last ( )
201+ . map ( |last| {
202+ !matches ! (
203+ last,
204+ dynamo_async_openai:: types:: ChatCompletionRequestMessage :: Assistant ( _)
205+ )
206+ } )
207+ . unwrap_or ( true )
205208 }
206209
207210 fn extract_text ( & self ) -> Option < TextInput > {
@@ -885,4 +888,46 @@ NORMAL MODE
885888 "https://example.com/vid.mp4"
886889 ) ;
887890 }
891+
892+ fn user ( ) -> Msg {
893+ Msg :: User ( Default :: default ( ) )
894+ }
895+ fn asst ( ) -> Msg {
896+ Msg :: Assistant ( Default :: default ( ) )
897+ }
898+ fn tool ( ) -> Msg {
899+ Msg :: Tool ( Default :: default ( ) )
900+ }
901+
902+ fn dummy_state ( messages : Vec < Msg > ) -> NvCreateChatCompletionRequest {
903+ let json = serde_json:: json!( {
904+ "model" : "test-model" ,
905+ "messages" : messages
906+ } ) ;
907+ serde_json:: from_value ( json) . unwrap ( )
908+ }
909+
910+ #[ test]
911+ fn add_after_user ( ) {
912+ let s = dummy_state ( vec ! [ user( ) ] ) ;
913+ assert ! ( s. should_add_generation_prompt( ) ) ;
914+ }
915+
916+ #[ test]
917+ fn add_after_tool ( ) {
918+ let s = dummy_state ( vec ! [ tool( ) ] ) ;
919+ assert ! ( s. should_add_generation_prompt( ) ) ;
920+ }
921+
922+ #[ test]
923+ fn no_after_assistant ( ) {
924+ let s = dummy_state ( vec ! [ asst( ) ] ) ;
925+ assert ! ( !s. should_add_generation_prompt( ) ) ;
926+ }
927+
928+ #[ test]
929+ fn add_when_empty ( ) {
930+ let s = dummy_state ( vec ! [ ] ) ;
931+ assert ! ( s. should_add_generation_prompt( ) ) ;
932+ }
888933}
0 commit comments