Skip to content

Commit fe8e5a8

Browse files
committed
chore: rebase
1 parent 9e2775b commit fe8e5a8

File tree

1 file changed

+53
-8
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+53
-8
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
194194
}
195195

196196
fn should_add_generation_prompt(&self) -> bool {
197-
if let Some(last) = self.inner.messages.last() {
198-
matches!(
199-
last,
200-
dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
201-
)
202-
} else {
203-
true
204-
}
197+
// Only add generation prompt if the last message was not assistant (default to true when no last message)
198+
self.inner
199+
.messages
200+
.last()
201+
.map(|last| {
202+
!matches!(
203+
last,
204+
dynamo_async_openai::types::ChatCompletionRequestMessage::Assistant(_)
205+
)
206+
})
207+
.unwrap_or(true)
205208
}
206209

207210
fn extract_text(&self) -> Option<TextInput> {
@@ -885,4 +888,46 @@ NORMAL MODE
885888
"https://example.com/vid.mp4"
886889
);
887890
}
891+
892+
fn user() -> Msg {
893+
Msg::User(Default::default())
894+
}
895+
fn asst() -> Msg {
896+
Msg::Assistant(Default::default())
897+
}
898+
fn tool() -> Msg {
899+
Msg::Tool(Default::default())
900+
}
901+
902+
fn dummy_state(messages: Vec<Msg>) -> NvCreateChatCompletionRequest {
903+
let json = serde_json::json!({
904+
"model": "test-model",
905+
"messages": messages
906+
});
907+
serde_json::from_value(json).unwrap()
908+
}
909+
910+
#[test]
911+
fn add_after_user() {
912+
let s = dummy_state(vec![user()]);
913+
assert!(s.should_add_generation_prompt());
914+
}
915+
916+
#[test]
917+
fn add_after_tool() {
918+
let s = dummy_state(vec![tool()]);
919+
assert!(s.should_add_generation_prompt());
920+
}
921+
922+
#[test]
923+
fn no_after_assistant() {
924+
let s = dummy_state(vec![asst()]);
925+
assert!(!s.should_add_generation_prompt());
926+
}
927+
928+
#[test]
929+
fn add_when_empty() {
930+
let s = dummy_state(vec![]);
931+
assert!(s.should_add_generation_prompt());
932+
}
888933
}

0 commit comments

Comments
 (0)