@@ -194,17 +194,14 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
194194 }
195195
196196 fn should_add_generation_prompt ( & self ) -> bool {
197- // Only add generation prompt if the last message was not assistant (default to true when no last message)
198- self . inner
199- . messages
200- . last ( )
201- . map ( |last| {
202- !matches ! (
203- last,
204- dynamo_async_openai:: types:: ChatCompletionRequestMessage :: Assistant ( _)
205- )
206- } )
207- . unwrap_or ( true )
197+ if let Some ( last) = self . inner . messages . last ( ) {
198+ matches ! (
199+ last,
200+ dynamo_async_openai:: types:: ChatCompletionRequestMessage :: User ( _)
201+ )
202+ } else {
203+ true
204+ }
208205 }
209206
210207 fn extract_text ( & self ) -> Option < TextInput > {
@@ -334,6 +331,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
334331mod tests {
335332 use super :: * ;
336333 use dynamo_async_openai:: types:: ChatCompletionRequestMessage as Msg ;
334+ use minijinja:: { Environment , context} ;
337335
338336 #[ test]
339337 fn test_may_be_fix_tool_schema_missing_type_and_properties ( ) {
@@ -741,52 +739,8 @@ NORMAL MODE
741739 assert_eq ! ( messages[ 0 ] [ "content" ] . as_array( ) . unwrap( ) . len( ) , 3 ) ;
742740 }
743741
744- fn user ( ) -> Msg {
745- Msg :: User ( Default :: default ( ) )
746- }
747- fn asst ( ) -> Msg {
748- Msg :: Assistant ( Default :: default ( ) )
749- }
750- fn tool ( ) -> Msg {
751- Msg :: Tool ( Default :: default ( ) )
752- }
753-
754- fn dummy_state ( messages : Vec < Msg > ) -> NvCreateChatCompletionRequest {
755- let json = serde_json:: json!( {
756- "model" : "test-model" ,
757- "messages" : messages
758- } ) ;
759- serde_json:: from_value ( json) . unwrap ( )
760- }
761-
762- #[ test]
763- fn add_after_user ( ) {
764- let s = dummy_state ( vec ! [ user( ) ] ) ;
765- assert ! ( s. should_add_generation_prompt( ) ) ;
766- }
767-
768- #[ test]
769- fn add_after_tool ( ) {
770- let s = dummy_state ( vec ! [ tool( ) ] ) ;
771- assert ! ( s. should_add_generation_prompt( ) ) ;
772- }
773-
774- #[ test]
775- fn no_after_assistant ( ) {
776- let s = dummy_state ( vec ! [ asst( ) ] ) ;
777- assert ! ( !s. should_add_generation_prompt( ) ) ;
778- }
779-
780- #[ test]
781- fn add_when_empty ( ) {
782- let s = dummy_state ( vec ! [ ] ) ;
783- assert ! ( s. should_add_generation_prompt( ) ) ;
784- }
785-
786742 #[ test]
787743 fn test_normalize_tool_arguments_tojson ( ) {
788- use minijinja:: { Environment , context} ;
789-
790744 let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"# ;
791745
792746 // Message with tool_calls containing JSON string arguments
@@ -821,8 +775,6 @@ NORMAL MODE
821775
822776 #[ test]
823777 fn test_normalize_tool_arguments_items_loop ( ) {
824- use minijinja:: { Environment , context} ;
825-
826778 let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"# ;
827779
828780 let mut messages = serde_json:: Value :: Array ( vec ! [ serde_json:: json!( {
@@ -888,6 +840,5 @@ NORMAL MODE
888840 messages[ 0 ] [ "tool_calls" ] [ 0 ] [ "function" ] [ "arguments" ] ,
889841 serde_json:: Value :: String ( "not valid json at all" . to_string( ) )
890842 ) ;
891-
892843 }
893844}
0 commit comments