diff --git a/lib/llm/src/preprocessor/prompt/template/oai.rs b/lib/llm/src/preprocessor/prompt/template/oai.rs index 23428e13da..78c6d35eec 100644 --- a/lib/llm/src/preprocessor/prompt/template/oai.rs +++ b/lib/llm/src/preprocessor/prompt/template/oai.rs @@ -163,14 +163,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { } fn should_add_generation_prompt(&self) -> bool { - if let Some(last) = self.inner.messages.last() { - matches!( - last, - dynamo_async_openai::types::ChatCompletionRequestMessage::User(_) - ) - } else { - true - } + // Only add generation prompt if the last message was not assistant (default to true when no last message) + self.inner + .messages + .last() + .map(|last| { + !matches!( + last, + dynamo_async_openai::types::ChatCompletionRequestMessage::Assistant(_) + ) + }) + .unwrap_or(true) } fn extract_text(&self) -> Option { @@ -294,6 +297,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { #[cfg(test)] mod tests { use super::*; + use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg; #[test] fn test_may_be_fix_tool_schema_missing_type_and_properties() { @@ -700,4 +704,46 @@ NORMAL MODE assert!(messages[0]["content"].is_array()); assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3); } + + fn user() -> Msg { + Msg::User(Default::default()) + } + fn asst() -> Msg { + Msg::Assistant(Default::default()) + } + fn tool() -> Msg { + Msg::Tool(Default::default()) + } + + fn dummy_state(messages: Vec) -> NvCreateChatCompletionRequest { + let json = serde_json::json!({ + "model": "test-model", + "messages": messages + }); + serde_json::from_value(json).unwrap() + } + + #[test] + fn add_after_user() { + let s = dummy_state(vec![user()]); + assert!(s.should_add_generation_prompt()); + } + + #[test] + fn add_after_tool() { + let s = dummy_state(vec![tool()]); + assert!(s.should_add_generation_prompt()); + } + + #[test] + fn no_after_assistant() { + let s = dummy_state(vec![asst()]); + assert!(!s.should_add_generation_prompt()); + } + + #[test] + fn add_when_empty() { + let s = dummy_state(vec![]); + assert!(s.should_add_generation_prompt()); + } }