13 changes: 11 additions & 2 deletions lib/llm/src/backend.rs
@@ -94,12 +94,13 @@ impl Backend
stream: ManyOut<ExecutionOutputStream>,
prompt_token_ids: &[TokenIdType],
stop_conditions: StopConditions,
skip_special_tokens: bool,
) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
};
let decoder = Decoder::new(
tokenizer.decode_stream(prompt_token_ids, false),
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
stop_conditions,
);

@@ -129,10 +130,18 @@ impl

let prompt_token_ids = request.token_ids.clone();

// TODO: Consider updating default to true to match behavior of other frameworks
let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);

let next_stream = next.generate(request).await?;

let context = next_stream.context();
let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
let state = self.decoder(
next_stream,
&prompt_token_ids,
stop_conditions,
skip_special_tokens,
)?;

let processed_stream = stream::unfold(state, |mut state| async move {
match state.stream.next().await {
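The default resolution above is just `Option::unwrap_or`. For context, a self-contained sketch (the helper name is hypothetical, not part of this crate) of how the three request states map to the flag handed to `decode_stream`:

```rust
/// Illustrative stand-in for the resolution in `backend.rs` above:
/// an absent field keeps the current default of `false`.
fn resolve_skip_special_tokens(requested: Option<bool>) -> bool {
    // Per the TODO in the diff, other frameworks default to `true`;
    // this change keeps `false` as the default for now.
    requested.unwrap_or(false)
}

fn main() {
    assert!(!resolve_skip_special_tokens(None));        // unset -> keep special tokens
    assert!(resolve_skip_special_tokens(Some(true)));   // opt in -> strip special tokens
    assert!(!resolve_skip_special_tokens(Some(false))); // explicit -> keep special tokens
}
```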
2 changes: 0 additions & 2 deletions lib/llm/src/protocols/common.rs
@@ -473,8 +473,6 @@ pub struct OutputOptions
pub prompt_logprobs: Option<u32>,

/// Whether to skip special tokens in the output.
/// spaces_between_special_tokens: Whether to add spaces between special
/// tokens in the output. Defaults to True.
pub skip_special_tokens: Option<bool>,

/// If true, the Context object will contain the prompt that was passed to
74 changes: 73 additions & 1 deletion lib/llm/src/protocols/openai/chat_completions.rs
@@ -198,6 +198,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}

fn get_skip_special_tokens(&self) -> Option<bool> {
self.common.skip_special_tokens
}
}

/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
@@ -263,7 +267,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
}

fn get_skip_special_tokens(&self) -> Option<bool> {
None
CommonExtProvider::get_skip_special_tokens(self)
}

fn get_formatted_prompt(&self) -> Option<bool> {
@@ -316,3 +320,71 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::OutputOptionsProvider;
use serde_json::json;

#[test]
fn test_skip_special_tokens_none() {
let json_str = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
});

let request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

assert_eq!(request.common.skip_special_tokens, None);

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, None);
}

#[test]
fn test_skip_special_tokens_true_propagates() {
let json_str = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"skip_special_tokens": true
});

let request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, Some(true));
}

#[test]
fn test_skip_special_tokens_false_propagates() {
let json_str = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"skip_special_tokens": false
});

let request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, Some(false));
}
}
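For context, the wire shape these tests exercise: the extension field sits at the top level of an otherwise standard OpenAI-style chat request. A minimal standalone sketch (model name and payload are placeholders), which applies identically to the completions request below:

```rust
use serde_json::json;

fn main() {
    // Hypothetical client payload: `skip_special_tokens` rides alongside the
    // standard OpenAI chat-completions fields, as the tests above deserialize it.
    let body = json!({
        "model": "test-model",
        "messages": [{ "role": "user", "content": "Hello" }],
        "skip_special_tokens": true
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```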
63 changes: 63 additions & 0 deletions lib/llm/src/protocols/openai/common_ext.rs
@@ -72,6 +72,14 @@ pub struct CommonExt {
#[builder(default, setter(strip_option))]
#[allow(unused)] // Not used
pub guided_whitespace_pattern: Option<String>,

/// Whether to skip special tokens in the decoded output.
/// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
/// When false, special tokens are included in the output text.
/// Defaults to false if not specified.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub skip_special_tokens: Option<bool>,
}

impl CommonExt {
@@ -99,6 +107,9 @@ pub trait CommonExtProvider {
fn get_min_p(&self) -> Option<f32>;
fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>;

/// Output Options
fn get_skip_special_tokens(&self) -> Option<bool>;
}

#[cfg(test)]
@@ -120,6 +131,7 @@ mod tests {
assert_eq!(common_ext.guided_choice, None);
assert_eq!(common_ext.guided_decoding_backend, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert_eq!(common_ext.skip_special_tokens, None);
}

#[test]
@@ -135,6 +147,7 @@
.guided_grammar("grammar".to_string())
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
.guided_decoding_backend("backend".to_string())
.skip_special_tokens(false)
.build()
.unwrap();

@@ -157,6 +170,7 @@
common_ext.guided_decoding_backend,
Some("backend".to_string())
);
assert_eq!(common_ext.skip_special_tokens, Some(false));
}

#[test]
@@ -190,6 +204,7 @@
guided_choice: None,
guided_decoding_backend: None,
guided_whitespace_pattern: None,
skip_special_tokens: None,
};
assert!(common_ext.validate().is_ok());
}
@@ -219,4 +234,52 @@
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok());
}

#[test]
fn test_skip_special_tokens_field() {
// Test that skip_special_tokens can be set and retrieved
let common_ext = CommonExt::builder()
.skip_special_tokens(true)
.build()
.unwrap();

assert_eq!(common_ext.skip_special_tokens, Some(true));

let common_ext = CommonExt::builder()
.skip_special_tokens(false)
.build()
.unwrap();

assert_eq!(common_ext.skip_special_tokens, Some(false));
}

#[test]
fn test_skip_special_tokens_serialization() {
// Test that skip_special_tokens can be serialized and deserialized
let common_ext = CommonExt::builder()
.skip_special_tokens(true)
.build()
.unwrap();

let json = serde_json::to_string(&common_ext).unwrap();
let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

assert_eq!(deserialized.skip_special_tokens, Some(true));

// Test with false value
let common_ext = CommonExt::builder()
.skip_special_tokens(false)
.build()
.unwrap();

let json = serde_json::to_string(&common_ext).unwrap();
let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

assert_eq!(deserialized.skip_special_tokens, Some(false));

// Test that None is not serialized (skip_serializing_if = "Option::is_none")
let common_ext = CommonExt::builder().build().unwrap();
let json = serde_json::to_string(&common_ext).unwrap();
assert!(!json.contains("skip_special_tokens"));
}
}
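The serde attributes on the new field give it asymmetric behavior that the serialization test above relies on: an absent key deserializes to `None`, `None` serializes back to an absent key, and `Some(false)` is preserved rather than dropped. A minimal sketch with a stand-in struct (not the real `CommonExt`) carrying the same attributes:

```rust
use serde::{Deserialize, Serialize};

// Stand-in struct with the same attributes as the new `CommonExt` field.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Opts {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    skip_special_tokens: Option<bool>,
}

fn main() {
    // Absent key -> None, via #[serde(default)].
    let parsed: Opts = serde_json::from_str("{}").unwrap();
    assert_eq!(parsed.skip_special_tokens, None);

    // None -> key omitted entirely, via skip_serializing_if.
    assert_eq!(serde_json::to_string(&parsed).unwrap(), "{}");

    // Some(false) still serializes; only None is skipped.
    let explicit = Opts { skip_special_tokens: Some(false) };
    assert_eq!(
        serde_json::to_string(&explicit).unwrap(),
        r#"{"skip_special_tokens":false}"#
    );
}
```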
68 changes: 67 additions & 1 deletion lib/llm/src/protocols/openai/completions.rs
@@ -189,6 +189,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}

fn get_skip_special_tokens(&self) -> Option<bool> {
self.common.skip_special_tokens
}
}

impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Expand Down Expand Up @@ -364,7 +368,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
}

fn get_skip_special_tokens(&self) -> Option<bool> {
None
CommonExtProvider::get_skip_special_tokens(self)
}

fn get_formatted_prompt(&self) -> Option<bool> {
@@ -407,3 +411,65 @@ impl ValidateRequest for NvCreateCompletionRequest {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::OutputOptionsProvider;
use serde_json::json;

#[test]
fn test_skip_special_tokens_none() {
let json_str = json!({
"model": "test-model",
"prompt": "Hello, world!"
});

let request: NvCreateCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

assert_eq!(request.common.skip_special_tokens, None);

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, None);
}

#[test]
fn test_skip_special_tokens_true_propagates() {
let json_str = json!({
"model": "test-model",
"prompt": "Hello, world!",
"skip_special_tokens": true
});

let request: NvCreateCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, Some(true));
}

#[test]
fn test_skip_special_tokens_false_propagates() {
let json_str = json!({
"model": "test-model",
"prompt": "Hello, world!",
"skip_special_tokens": false
});

let request: NvCreateCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");

let output_options = request
.extract_output_options()
.expect("Failed to extract output options");

assert_eq!(output_options.skip_special_tokens, Some(false));
}
}