Skip to content

Commit 6bfeee8

Browse files
committed
style: DRY principle for test import
Signed-off-by: Ryan Lempka <[email protected]>
1 parent 4e35a68 commit 6bfeee8

File tree

1 file changed

+9
-58
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+9
-58
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 9 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,14 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
194194
}
195195

196196
fn should_add_generation_prompt(&self) -> bool {
197-
// Only add generation prompt if the last message was not assistant (default to true when no last message)
198-
self.inner
199-
.messages
200-
.last()
201-
.map(|last| {
202-
!matches!(
203-
last,
204-
dynamo_async_openai::types::ChatCompletionRequestMessage::Assistant(_)
205-
)
206-
})
207-
.unwrap_or(true)
197+
if let Some(last) = self.inner.messages.last() {
198+
matches!(
199+
last,
200+
dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
201+
)
202+
} else {
203+
true
204+
}
208205
}
209206

210207
fn extract_text(&self) -> Option<TextInput> {
@@ -334,6 +331,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
334331
mod tests {
335332
use super::*;
336333
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
334+
use minijinja::{Environment, context};
337335

338336
#[test]
339337
fn test_may_be_fix_tool_schema_missing_type_and_properties() {
@@ -741,52 +739,8 @@ NORMAL MODE
741739
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
742740
}
743741

744-
fn user() -> Msg {
745-
Msg::User(Default::default())
746-
}
747-
fn asst() -> Msg {
748-
Msg::Assistant(Default::default())
749-
}
750-
fn tool() -> Msg {
751-
Msg::Tool(Default::default())
752-
}
753-
754-
fn dummy_state(messages: Vec<Msg>) -> NvCreateChatCompletionRequest {
755-
let json = serde_json::json!({
756-
"model": "test-model",
757-
"messages": messages
758-
});
759-
serde_json::from_value(json).unwrap()
760-
}
761-
762-
#[test]
763-
fn add_after_user() {
764-
let s = dummy_state(vec![user()]);
765-
assert!(s.should_add_generation_prompt());
766-
}
767-
768-
#[test]
769-
fn add_after_tool() {
770-
let s = dummy_state(vec![tool()]);
771-
assert!(s.should_add_generation_prompt());
772-
}
773-
774-
#[test]
775-
fn no_after_assistant() {
776-
let s = dummy_state(vec![asst()]);
777-
assert!(!s.should_add_generation_prompt());
778-
}
779-
780-
#[test]
781-
fn add_when_empty() {
782-
let s = dummy_state(vec![]);
783-
assert!(s.should_add_generation_prompt());
784-
}
785-
786742
#[test]
787743
fn test_normalize_tool_arguments_tojson() {
788-
use minijinja::{Environment, context};
789-
790744
let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
791745

792746
// Message with tool_calls containing JSON string arguments
@@ -821,8 +775,6 @@ NORMAL MODE
821775

822776
#[test]
823777
fn test_normalize_tool_arguments_items_loop() {
824-
use minijinja::{Environment, context};
825-
826778
let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;
827779

828780
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
@@ -888,6 +840,5 @@ NORMAL MODE
888840
messages[0]["tool_calls"][0]["function"]["arguments"],
889841
serde_json::Value::String("not valid json at all".to_string())
890842
);
891-
892843
}
893844
}

0 commit comments

Comments
 (0)