13 changes: 11 additions & 2 deletions lib/llm/src/backend.rs
@@ -94,12 +94,13 @@ impl Backend {
         stream: ManyOut<ExecutionOutputStream>,
         prompt_token_ids: &[TokenIdType],
         stop_conditions: StopConditions,
+        skip_special_tokens: bool,
     ) -> anyhow::Result<DecoderUnfoldState> {
         let Some(tokenizer) = self.tokenizer.as_ref() else {
             anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
         };
         let decoder = Decoder::new(
-            tokenizer.decode_stream(prompt_token_ids, false),
+            tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
             stop_conditions,
         );

@@ -129,10 +130,18 @@ impl

         let prompt_token_ids = request.token_ids.clone();

+        // TODO: Consider updating default to true to match behavior of other frameworks
+        let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
+
         let next_stream = next.generate(request).await?;

         let context = next_stream.context();
-        let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
+        let state = self.decoder(
+            next_stream,
+            &prompt_token_ids,
+            stop_conditions,
+            skip_special_tokens,
+        )?;

         let processed_stream = stream::unfold(state, |mut state| async move {
             match state.stream.next().await {
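Note for reviewers: the new parameter threaded through here is ultimately the tokenizer-level skip_special_tokens switch. A standalone sketch of its observable effect, using the Hugging Face tokenizers crate, which exposes the same boolean on decode; this is illustrative only, not this PR's Decoder, and the tokenizer.json path is a placeholder:

// Standalone sketch (not part of this PR): what skip_special_tokens controls,
// shown via the Hugging Face `tokenizers` crate's decode flag.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path; any HF tokenizer file works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // Encode with special tokens added so the decode difference is visible.
    let encoding = tokenizer.encode("Hello", true)?;
    let ids = encoding.get_ids();

    // false: markers such as <s>/</s> remain in the output text.
    let kept = tokenizer.decode(ids, false)?;
    // true: the same markers are stripped from the output text.
    let stripped = tokenizer.decode(ids, true)?;
    println!("kept: {kept:?}\nstripped: {stripped:?}");
    Ok(())
}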
2 changes: 0 additions & 2 deletions lib/llm/src/protocols/common.rs
@@ -473,8 +473,6 @@ pub struct OutputOptions {
     pub prompt_logprobs: Option<u32>,

     /// Whether to skip special tokens in the output.
-    /// spaces_between_special_tokens: Whether to add spaces between special
-    /// tokens in the output. Defaults to True.
     pub skip_special_tokens: Option<bool>,

     /// If true, the Context object will contain the prompt that was pass to
74 changes: 73 additions & 1 deletion lib/llm/src/protocols/openai/chat_completions.rs
@@ -198,6 +198,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
     fn get_include_stop_str_in_output(&self) -> Option<bool> {
         self.common.include_stop_str_in_output
     }
+
+    fn get_skip_special_tokens(&self) -> Option<bool> {
+        self.common.skip_special_tokens
+    }
 }

/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
@@ -263,7 +267,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
     }

     fn get_skip_special_tokens(&self) -> Option<bool> {
-        None
+        CommonExtProvider::get_skip_special_tokens(self)
     }

     fn get_formatted_prompt(&self) -> Option<bool> {
@@ -316,3 +320,71 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::common::OutputOptionsProvider;
+    use serde_json::json;
+
+    #[test]
+    fn test_skip_special_tokens_none() {
+        let json_str = json!({
+            "model": "test-model",
+            "messages": [
+                {"role": "user", "content": "Hello"}
+            ]
+        });
+
+        let request: NvCreateChatCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        assert_eq!(request.common.skip_special_tokens, None);
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, None);
+    }
+
+    #[test]
+    fn test_skip_special_tokens_true_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "messages": [
+                {"role": "user", "content": "Hello"}
+            ],
+            "skip_special_tokens": true
+        });
+
+        let request: NvCreateChatCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(true));
+    }
+
+    #[test]
+    fn test_skip_special_tokens_false_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "messages": [
+                {"role": "user", "content": "Hello"}
+            ],
+            "skip_special_tokens": false
+        });
+
+        let request: NvCreateChatCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(false));
+    }
+}
63 changes: 63 additions & 0 deletions lib/llm/src/protocols/openai/common_ext.rs
@@ -72,6 +72,14 @@ pub struct CommonExt {
     #[builder(default, setter(strip_option))]
     #[allow(unused)] // Not used
     pub guided_whitespace_pattern: Option<String>,
+
+    /// Whether to skip special tokens in the decoded output.
+    /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
+    /// When false, special tokens are included in the output text.
+    /// Defaults to true if not specified (matching vLLM/TensorRT-LLM behavior).
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[builder(default, setter(strip_option))]
+    pub skip_special_tokens: Option<bool>,
Comment on lines +76 to +82
⚠️ Potential issue | 🟠 Major

Fix documentation inconsistency with actual default behavior.

The documentation states "Defaults to true if not specified (matching vLLM/TensorRT-LLM behavior)" but the actual implementation in lib/llm/src/backend.rs Line 134 uses unwrap_or(false), defaulting to false. This creates confusion about the expected behavior.

Update the documentation to reflect the current default:

-    /// Whether to skip special tokens in the decoded output.
-    /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
-    /// When false, special tokens are included in the output text.
-    /// Defaults to true if not specified (matching vLLM/TensorRT-LLM behavior).
+    /// Whether to skip special tokens in the decoded output.
+    /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
+    /// When false, special tokens are included in the output text.
+    /// Defaults to false if not specified. A future change may update this to true to match vLLM/TensorRT-LLM behavior.
🤖 Prompt for AI Agents
In lib/llm/src/protocols/openai/common_ext.rs around lines 76 to 82, the
docstring incorrectly states the default is true while the implementation uses
unwrap_or(false); update the documentation to state the default is false (i.e.,
"Defaults to false if not specified"), remove the parenthetical about matching
vLLM/TensorRT-LLM or change it to reflect the actual behavior, and ensure
serde/builder attributes remain unchanged.

}

impl CommonExt {
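
To make the review comment above concrete: the documented default and the implemented default diverge only when the client omits the field entirely. A minimal standalone sketch with hypothetical helpers (not code from this PR):

// Hypothetical helpers contrasting the documented default with the one
// implemented in backend.rs; they disagree only for an unset field.
fn documented_default(flag: Option<bool>) -> bool {
    flag.unwrap_or(true) // what the doc comment promises
}

fn implemented_default(flag: Option<bool>) -> bool {
    flag.unwrap_or(false) // what backend.rs does today
}

fn main() {
    assert_ne!(documented_default(None), implemented_default(None));
    assert_eq!(documented_default(Some(true)), implemented_default(Some(true)));
    assert_eq!(documented_default(Some(false)), implemented_default(Some(false)));
}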
@@ -99,6 +107,9 @@ pub trait CommonExtProvider {
     fn get_min_p(&self) -> Option<f32>;
     fn get_repetition_penalty(&self) -> Option<f32>;
     fn get_include_stop_str_in_output(&self) -> Option<bool>;
+
+    /// Output Options
+    fn get_skip_special_tokens(&self) -> Option<bool>;
}

#[cfg(test)]
@@ -120,6 +131,7 @@ mod tests {
         assert_eq!(common_ext.guided_choice, None);
         assert_eq!(common_ext.guided_decoding_backend, None);
         assert_eq!(common_ext.include_stop_str_in_output, None);
+        assert_eq!(common_ext.skip_special_tokens, None);
     }

     #[test]
@@ -135,6 +147,7 @@ mod tests {
             .guided_grammar("grammar".to_string())
             .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
             .guided_decoding_backend("backend".to_string())
+            .skip_special_tokens(false)
             .build()
             .unwrap();

@@ -157,6 +170,7 @@ mod tests {
             common_ext.guided_decoding_backend,
             Some("backend".to_string())
         );
+        assert_eq!(common_ext.skip_special_tokens, Some(false));
     }

     #[test]
@@ -190,6 +204,7 @@ mod tests {
             guided_choice: None,
             guided_decoding_backend: None,
             guided_whitespace_pattern: None,
+            skip_special_tokens: None,
         };
         assert!(common_ext.validate().is_ok());
     }
@@ -219,4 +234,52 @@ mod tests {
         assert_eq!(common_ext.include_stop_str_in_output, None);
         assert!(common_ext.validate().is_ok());
     }
+
+    #[test]
+    fn test_skip_special_tokens_field() {
+        // Test that skip_special_tokens can be set and retrieved
+        let common_ext = CommonExt::builder()
+            .skip_special_tokens(true)
+            .build()
+            .unwrap();
+
+        assert_eq!(common_ext.skip_special_tokens, Some(true));
+
+        let common_ext = CommonExt::builder()
+            .skip_special_tokens(false)
+            .build()
+            .unwrap();
+
+        assert_eq!(common_ext.skip_special_tokens, Some(false));
+    }
+
+    #[test]
+    fn test_skip_special_tokens_serialization() {
+        // Test that skip_special_tokens can be serialized and deserialized
+        let common_ext = CommonExt::builder()
+            .skip_special_tokens(true)
+            .build()
+            .unwrap();
+
+        let json = serde_json::to_string(&common_ext).unwrap();
+        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(deserialized.skip_special_tokens, Some(true));
+
+        // Test with false value
+        let common_ext = CommonExt::builder()
+            .skip_special_tokens(false)
+            .build()
+            .unwrap();
+
+        let json = serde_json::to_string(&common_ext).unwrap();
+        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(deserialized.skip_special_tokens, Some(false));
+
+        // Test that None is not serialized (skip_serializing_if = "Option::is_none")
+        let common_ext = CommonExt::builder().build().unwrap();
+        let json = serde_json::to_string(&common_ext).unwrap();
+        assert!(!json.contains("skip_special_tokens"));
+    }
}
68 changes: 67 additions & 1 deletion lib/llm/src/protocols/openai/completions.rs
@@ -189,6 +189,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
     fn get_include_stop_str_in_output(&self) -> Option<bool> {
         self.common.include_stop_str_in_output
     }
+
+    fn get_skip_special_tokens(&self) -> Option<bool> {
+        self.common.skip_special_tokens
+    }
 }

impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
@@ -364,7 +368,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
     }

     fn get_skip_special_tokens(&self) -> Option<bool> {
-        None
+        CommonExtProvider::get_skip_special_tokens(self)
     }

     fn get_formatted_prompt(&self) -> Option<bool> {
@@ -407,3 +411,65 @@ impl ValidateRequest for NvCreateCompletionRequest {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::common::OutputOptionsProvider;
+    use serde_json::json;
+
+    #[test]
+    fn test_skip_special_tokens_none() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!"
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        assert_eq!(request.common.skip_special_tokens, None);
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, None);
+    }
+
+    #[test]
+    fn test_skip_special_tokens_true_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!",
+            "skip_special_tokens": true
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(true));
+    }
+
+    #[test]
+    fn test_skip_special_tokens_false_propagates() {
+        let json_str = json!({
+            "model": "test-model",
+            "prompt": "Hello, world!",
+            "skip_special_tokens": false
+        });
+
+        let request: NvCreateCompletionRequest =
+            serde_json::from_value(json_str).expect("Failed to deserialize request");
+
+        let output_options = request
+            .extract_output_options()
+            .expect("Failed to extract output options");
+
+        assert_eq!(output_options.skip_special_tokens, Some(false));
+    }
+}
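
For reference, the propagation these tests exercise starts from a plain request body. A hypothetical client payload (model name is a placeholder), with the extension field sitting alongside the standard OpenAI fields:

// Hypothetical client-side payload: the extension field rides at the top
// level of an otherwise standard OpenAI-style completions request body.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "test-model",          // placeholder model name
        "prompt": "Hello, world!",
        "max_tokens": 16,
        // NVIDIA extension carried by CommonExt: omit it and the backend
        // currently falls back to false, i.e. special tokens are kept.
        "skip_special_tokens": true
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}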