Skip to content

Commit 3fd0ab3

Browse files
authored
fix: multi-turn bug in should_add_generation_prompt (#4168)
1 parent 7750ed1 commit 3fd0ab3

File tree

1 file changed

+54
-8
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+54
-8
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
163163
}
164164

165165
fn should_add_generation_prompt(&self) -> bool {
166-
if let Some(last) = self.inner.messages.last() {
167-
matches!(
168-
last,
169-
dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
170-
)
171-
} else {
172-
true
173-
}
166+
// Only add generation prompt if the last message was not assistant (default to true when no last message)
167+
self.inner
168+
.messages
169+
.last()
170+
.map(|last| {
171+
!matches!(
172+
last,
173+
dynamo_async_openai::types::ChatCompletionRequestMessage::Assistant(_)
174+
)
175+
})
176+
.unwrap_or(true)
174177
}
175178

176179
fn extract_text(&self) -> Option<TextInput> {
@@ -294,6 +297,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
294297
#[cfg(test)]
295298
mod tests {
296299
use super::*;
300+
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
297301

298302
#[test]
299303
fn test_may_be_fix_tool_schema_missing_type_and_properties() {
@@ -700,4 +704,46 @@ NORMAL MODE
700704
assert!(messages[0]["content"].is_array());
701705
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
702706
}
707+
708+
fn user() -> Msg {
709+
Msg::User(Default::default())
710+
}
711+
fn asst() -> Msg {
712+
Msg::Assistant(Default::default())
713+
}
714+
fn tool() -> Msg {
715+
Msg::Tool(Default::default())
716+
}
717+
718+
fn dummy_state(messages: Vec<Msg>) -> NvCreateChatCompletionRequest {
719+
let json = serde_json::json!({
720+
"model": "test-model",
721+
"messages": messages
722+
});
723+
serde_json::from_value(json).unwrap()
724+
}
725+
726+
#[test]
727+
fn add_after_user() {
728+
let s = dummy_state(vec![user()]);
729+
assert!(s.should_add_generation_prompt());
730+
}
731+
732+
#[test]
733+
fn add_after_tool() {
734+
let s = dummy_state(vec![tool()]);
735+
assert!(s.should_add_generation_prompt());
736+
}
737+
738+
#[test]
739+
fn no_after_assistant() {
740+
let s = dummy_state(vec![asst()]);
741+
assert!(!s.should_add_generation_prompt());
742+
}
743+
744+
#[test]
745+
fn add_when_empty() {
746+
let s = dummy_state(vec![]);
747+
assert!(s.should_add_generation_prompt());
748+
}
703749
}

0 commit comments

Comments
 (0)