diff --git a/lib/llm/src/backend.rs b/lib/llm/src/backend.rs
index a4b75d7aba..4e8db043ff 100644
--- a/lib/llm/src/backend.rs
+++ b/lib/llm/src/backend.rs
@@ -94,12 +94,13 @@ impl Backend {
         stream: ManyOut<…>,
         prompt_token_ids: &[TokenIdType],
         stop_conditions: StopConditions,
+        skip_special_tokens: bool,
     ) -> anyhow::Result<…> {
         let Some(tokenizer) = self.tokenizer.as_ref() else {
             anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
         };
         let decoder = Decoder::new(
-            tokenizer.decode_stream(prompt_token_ids, false),
+            tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
             stop_conditions,
         );
 
@@ -129,10 +130,18 @@ impl
         let prompt_token_ids = request.token_ids.clone();
 
+        // TODO: Consider updating default to true to match behavior of other frameworks
+        let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
+
         let next_stream = next.generate(request).await?;
         let context = next_stream.context();
 
-        let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
+        let state = self.decoder(
+            next_stream,
+            &prompt_token_ids,
+            stop_conditions,
+            skip_special_tokens,
+        )?;
 
         let processed_stream = stream::unfold(state, |mut state| async move {
             match state.stream.next().await {
diff --git a/lib/llm/src/protocols/common.rs b/lib/llm/src/protocols/common.rs
index 59d0cb002b..76d35f9fa8 100644
--- a/lib/llm/src/protocols/common.rs
+++ b/lib/llm/src/protocols/common.rs
@@ -473,8 +473,6 @@ pub struct OutputOptions {
     pub prompt_logprobs: Option<u32>,
 
     /// Whether to skip special tokens in the output.
-    /// spaces_between_special_tokens: Whether to add spaces between special
-    /// tokens in the output. Defaults to True.
    pub skip_special_tokens: Option<bool>,
 
     /// If true, the Context object will contain the prompt that was passed to
diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs
index 2dc2212952..26e7e5abfb 100644
--- a/lib/llm/src/protocols/openai/chat_completions.rs
+++ b/lib/llm/src/protocols/openai/chat_completions.rs
@@ -198,6 +198,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
     fn get_include_stop_str_in_output(&self) -> Option<bool> {
         self.common.include_stop_str_in_output
     }
+
+    fn get_skip_special_tokens(&self) -> Option<bool> {
+        self.common.skip_special_tokens
+    }
 }
 
 /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
@@ -263,7 +267,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
     }
 
     fn get_skip_special_tokens(&self) -> Option<bool> {
-        None
+        CommonExtProvider::get_skip_special_tokens(self)
     }
 
     fn get_formatted_prompt(&self) -> Option<String> {
@@ -316,3 +320,71 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::common::OutputOptionsProvider;
+    use serde_json::json;
+
+    #[test]
+    fn test_skip_special_tokens_none() {
+        let json_str = json!({
+            "model": "test-model",
+            "messages": [
+                {"role": "user", "content": "Hello"}
+            ]
+        });
+
+        let request: NvCreateChatCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        assert_eq!(request.common.skip_special_tokens, None);
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, None);
+    }
+
+    #[test]
+    fn test_skip_special_tokens_true_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "messages": [
+                {"role": "user", "content": "Hello"}
+            ],
"skip_special_tokens": true + }); + + let request: NvCreateChatCompletionRequest = + serde_json::from_value(json_str).expect("Failed to deserialize request"); + + let output_options = request + .extract_output_options() + .expect("Failed to extract output options"); + + assert_eq!(output_options.skip_special_tokens, Some(true)); + } + + #[test] + fn test_skip_special_tokens_false_propagates() { + let json_str = json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "skip_special_tokens": false + }); + + let request: NvCreateChatCompletionRequest = + serde_json::from_value(json_str).expect("Failed to deserialize request"); + + let output_options = request + .extract_output_options() + .expect("Failed to extract output options"); + + assert_eq!(output_options.skip_special_tokens, Some(false)); + } +} diff --git a/lib/llm/src/protocols/openai/common_ext.rs b/lib/llm/src/protocols/openai/common_ext.rs index fdcf6db3bf..a77f765ae6 100644 --- a/lib/llm/src/protocols/openai/common_ext.rs +++ b/lib/llm/src/protocols/openai/common_ext.rs @@ -72,6 +72,14 @@ pub struct CommonExt { #[builder(default, setter(strip_option))] #[allow(unused)] // Not used pub guided_whitespace_pattern: Option, + + /// Whether to skip special tokens in the decoded output. + /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text. + /// When false, special tokens are included in the output text. + /// Defaults to false if not specified. + #[serde(default, skip_serializing_if = "Option::is_none")] + #[builder(default, setter(strip_option))] + pub skip_special_tokens: Option, } impl CommonExt { @@ -99,6 +107,9 @@ pub trait CommonExtProvider { fn get_min_p(&self) -> Option; fn get_repetition_penalty(&self) -> Option; fn get_include_stop_str_in_output(&self) -> Option; + + /// Output Options + fn get_skip_special_tokens(&self) -> Option; } #[cfg(test)] @@ -120,6 +131,7 @@ mod tests { assert_eq!(common_ext.guided_choice, None); assert_eq!(common_ext.guided_decoding_backend, None); assert_eq!(common_ext.include_stop_str_in_output, None); + assert_eq!(common_ext.skip_special_tokens, None); } #[test] @@ -135,6 +147,7 @@ mod tests { .guided_grammar("grammar".to_string()) .guided_choice(vec!["choice1".to_string(), "choice2".to_string()]) .guided_decoding_backend("backend".to_string()) + .skip_special_tokens(false) .build() .unwrap(); @@ -157,6 +170,7 @@ mod tests { common_ext.guided_decoding_backend, Some("backend".to_string()) ); + assert_eq!(common_ext.skip_special_tokens, Some(false)); } #[test] @@ -190,6 +204,7 @@ mod tests { guided_choice: None, guided_decoding_backend: None, guided_whitespace_pattern: None, + skip_special_tokens: None, }; assert!(common_ext.validate().is_ok()); } @@ -219,4 +234,52 @@ mod tests { assert_eq!(common_ext.include_stop_str_in_output, None); assert!(common_ext.validate().is_ok()); } + + #[test] + fn test_skip_special_tokens_field() { + // Test that skip_special_tokens can be set and retrieved + let common_ext = CommonExt::builder() + .skip_special_tokens(true) + .build() + .unwrap(); + + assert_eq!(common_ext.skip_special_tokens, Some(true)); + + let common_ext = CommonExt::builder() + .skip_special_tokens(false) + .build() + .unwrap(); + + assert_eq!(common_ext.skip_special_tokens, Some(false)); + } + + #[test] + fn test_skip_special_tokens_serialization() { + // Test that skip_special_tokens can be serialized and deserialized + let common_ext = CommonExt::builder() + .skip_special_tokens(true) + .build() + .unwrap(); + + let json = 
+        let json = serde_json::to_string(&common_ext).unwrap();
+        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(deserialized.skip_special_tokens, Some(true));
+
+        // Test with false value
+        let common_ext = CommonExt::builder()
+            .skip_special_tokens(false)
+            .build()
+            .unwrap();
+
+        let json = serde_json::to_string(&common_ext).unwrap();
+        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(deserialized.skip_special_tokens, Some(false));
+
+        // Test that None is not serialized (skip_serializing_if = "Option::is_none")
+        let common_ext = CommonExt::builder().build().unwrap();
+        let json = serde_json::to_string(&common_ext).unwrap();
+        assert!(!json.contains("skip_special_tokens"));
+    }
 }
diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs
index 45d5885b6c..ac703d2517 100644
--- a/lib/llm/src/protocols/openai/completions.rs
+++ b/lib/llm/src/protocols/openai/completions.rs
@@ -189,6 +189,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
     fn get_include_stop_str_in_output(&self) -> Option<bool> {
         self.common.include_stop_str_in_output
     }
+
+    fn get_skip_special_tokens(&self) -> Option<bool> {
+        self.common.skip_special_tokens
+    }
 }
 
 impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
@@ -364,7 +368,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
     }
 
     fn get_skip_special_tokens(&self) -> Option<bool> {
-        None
+        CommonExtProvider::get_skip_special_tokens(self)
     }
 
     fn get_formatted_prompt(&self) -> Option<String> {
@@ -407,3 +411,65 @@ impl ValidateRequest for NvCreateCompletionRequest {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::common::OutputOptionsProvider;
+    use serde_json::json;
+
+    #[test]
+    fn test_skip_special_tokens_none() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!"
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        assert_eq!(request.common.skip_special_tokens, None);
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, None);
+    }
+
+    #[test]
+    fn test_skip_special_tokens_true_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!",
+            "skip_special_tokens": true
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(true));
+    }
+
+    #[test]
+    fn test_skip_special_tokens_false_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!",
+            "skip_special_tokens": false
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(false));
+    }
+}
diff --git a/lib/llm/tests/tokenizers.rs b/lib/llm/tests/tokenizers.rs
index 9cba788926..478b206948 100644
--- a/lib/llm/tests/tokenizers.rs
+++ b/lib/llm/tests/tokenizers.rs
@@ -178,3 +178,78 @@ fn test_long_sequence_incremental_decode_with_prefill() {
         assert_eq!(output.trim(), output_text.to_string());
     }
 }
+
+#[test]
+fn test_decode_with_skip_special_tokens() {
+    let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
+        .expect("Failed to load HuggingFace tokenizer");
+
+    // Create a sequence with special tokens:
+    // <s> (token_id: 1) + "Hello world" + </s> (token_id: 2)
+    let text = "Hello world";
+    let encoding = tokenizer.encode(text).expect("Failed to encode text");
+    let mut token_ids = vec![1]; // <s>
+    token_ids.extend(encoding.token_ids());
+    token_ids.push(2); // </s>
+
+    // Decode with skip_special_tokens = false (should keep special tokens)
+    let decoded_with_special = tokenizer
+        .decode(&token_ids, false)
+        .expect("Failed to decode with skip_special_tokens=false");
+
+    // Decode with skip_special_tokens = true (should remove special tokens)
+    let decoded_without_special = tokenizer
+        .decode(&token_ids, true)
+        .expect("Failed to decode with skip_special_tokens=true");
+
+    // Print the decoded values for visibility
+    println!("Token IDs: {:?}", token_ids);
+    println!(
+        "Decoded WITH special tokens (skip=false): {:?}",
+        decoded_with_special
+    );
+    println!(
+        "Decoded WITHOUT special tokens (skip=true): {:?}",
+        decoded_without_special
+    );
+
+    // Verify that the version with special tokens contains the special token markers
+    assert!(
+        decoded_with_special.contains("<s>"),
+        "Expected decoded text with skip_special_tokens=false to contain '<s>', but got: {}",
+        decoded_with_special
+    );
+    assert!(
+        decoded_with_special.contains("</s>"),
+        "Expected decoded text with skip_special_tokens=false to contain '</s>', but got: {}",
+        decoded_with_special
+    );
+
+    // Verify that the version without special tokens does NOT contain the special token markers
+    assert!(
+        !decoded_without_special.contains("<s>"),
+        "Expected decoded text with skip_special_tokens=true to NOT contain '<s>', but got: {}",
+        decoded_without_special
+    );
+    assert!(
+        !decoded_without_special.contains("</s>"),
+        "Expected decoded text with skip_special_tokens=true to NOT contain '</s>', but got: {}",
+        decoded_without_special
+    );
+
+    // The text content should be present in both versions
+    assert!(
+        decoded_with_special.contains(text),
+        "Expected decoded text with skip_special_tokens=false to contain the original text"
+    );
+    assert!(
+        decoded_without_special.contains(text),
+        "Expected decoded text with skip_special_tokens=true to contain the original text"
+    );
+
+    // The version without special tokens should be shorter
+    assert!(
+        decoded_without_special.len() < decoded_with_special.len(),
+        "Expected decoded text with skip_special_tokens=true to be shorter than with skip_special_tokens=false"
+    );
+}
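
The flag plumbed through above ultimately toggles the skip_special_tokens argument of the decode path in the underlying HuggingFace `tokenizers` crate. The following standalone sketch (not part of this patch) shows the same effect directly against that crate; the "tokenizer.json" path is a placeholder, it assumes a Llama-style vocabulary where `<s>` = 1 and `</s>` = 2 (as in the test above), and the decoded strings in the comments are illustrative.

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path: any Llama-style tokenizer.json works for this sketch.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Build <s> + "Hello world" + </s>, mirroring the integration test above.
    // add_special_tokens = false so BOS/EOS are only the ones we add by hand.
    let mut ids: Vec<u32> = vec![1]; // <s> (BOS)
    ids.extend(tokenizer.encode("Hello world", false)?.get_ids());
    ids.push(2); // </s> (EOS)

    // skip_special_tokens = false: special tokens remain in the text,
    // e.g. "<s> Hello world</s>"
    println!("{}", tokenizer.decode(&ids, false)?);

    // skip_special_tokens = true: special tokens are stripped,
    // e.g. "Hello world"
    println!("{}", tokenizer.decode(&ids, true)?);

    Ok(())
}

Defaulting the request option with `unwrap_or(false)` preserves the existing output behavior; the TODO in backend.rs notes that other frameworks, such as vLLM, default skip_special_tokens to true.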