Skip to content

Commit 441473c

Browse files
authored
feat: Add support for skip_special_tokens parameter in v1/completions and v1/chat/completions endpoints (#4175)
1 parent 14af074 commit 441473c

File tree

6 files changed

+208
-6
lines changed

6 files changed

+208
-6
lines changed

lib/llm/src/backend.rs

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,12 +94,13 @@ impl Backend {
9494
stream: ManyOut<ExecutionOutputStream>,
9595
prompt_token_ids: &[TokenIdType],
9696
stop_conditions: StopConditions,
97+
skip_special_tokens: bool,
9798
) -> anyhow::Result<DecoderUnfoldState> {
9899
let Some(tokenizer) = self.tokenizer.as_ref() else {
99100
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
100101
};
101102
let decoder = Decoder::new(
102-
tokenizer.decode_stream(prompt_token_ids, false),
103+
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
103104
stop_conditions,
104105
);
105106

@@ -129,10 +130,18 @@ impl
129130

130131
let prompt_token_ids = request.token_ids.clone();
131132

133+
// TODO: Consider updating default to true to match behavior of other frameworks
134+
let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
135+
132136
let next_stream = next.generate(request).await?;
133137

134138
let context = next_stream.context();
135-
let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
139+
let state = self.decoder(
140+
next_stream,
141+
&prompt_token_ids,
142+
stop_conditions,
143+
skip_special_tokens,
144+
)?;
136145

137146
let processed_stream = stream::unfold(state, |mut state| async move {
138147
match state.stream.next().await {

lib/llm/src/protocols/common.rs

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -473,8 +473,6 @@ pub struct OutputOptions {
473473
pub prompt_logprobs: Option<u32>,
474474

475475
/// Whether to skip special tokens in the output.
476-
/// spaces_between_special_tokens: Whether to add spaces between special
477-
/// tokens in the output. Defaults to True.
478476
pub skip_special_tokens: Option<bool>,
479477

480478
/// If true, the Context object will contain the prompt that was pass to

lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 55 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
198198
fn get_include_stop_str_in_output(&self) -> Option<bool> {
199199
self.common.include_stop_str_in_output
200200
}
201+
202+
fn get_skip_special_tokens(&self) -> Option<bool> {
203+
self.common.skip_special_tokens
204+
}
201205
}
202206

203207
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
@@ -263,7 +267,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
263267
}
264268

265269
fn get_skip_special_tokens(&self) -> Option<bool> {
266-
None
270+
CommonExtProvider::get_skip_special_tokens(self)
267271
}
268272

269273
fn get_formatted_prompt(&self) -> Option<bool> {
@@ -316,3 +320,53 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
316320
Ok(())
317321
}
318322
}
323+
324+
#[cfg(test)]
325+
mod tests {
326+
use super::*;
327+
use crate::protocols::common::OutputOptionsProvider;
328+
use serde_json::json;
329+
330+
#[test]
331+
fn test_skip_special_tokens_none() {
332+
let json_str = json!({
333+
"model": "test-model",
334+
"messages": [
335+
{"role": "user", "content": "Hello"}
336+
]
337+
});
338+
339+
let request: NvCreateChatCompletionRequest =
340+
serde_json::from_value(json_str).expect("Failed to deserialize request");
341+
342+
assert_eq!(request.common.skip_special_tokens, None);
343+
344+
let output_options = request
345+
.extract_output_options()
346+
.expect("Failed to extract output options");
347+
348+
assert_eq!(output_options.skip_special_tokens, None);
349+
}
350+
351+
#[test]
352+
fn test_skip_special_tokens_propagates() {
353+
for skip_value in [true, false] {
354+
let json_str = json!({
355+
"model": "test-model",
356+
"messages": [
357+
{"role": "user", "content": "Hello"}
358+
],
359+
"skip_special_tokens": skip_value
360+
});
361+
362+
let request: NvCreateChatCompletionRequest =
363+
serde_json::from_value(json_str).expect("Failed to deserialize request");
364+
365+
let output_options = request
366+
.extract_output_options()
367+
.expect("Failed to extract output options");
368+
369+
assert_eq!(output_options.skip_special_tokens, Some(skip_value));
370+
}
371+
}
372+
}

lib/llm/src/protocols/openai/common_ext.rs

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,14 @@ pub struct CommonExt {
7272
#[builder(default, setter(strip_option))]
7373
#[allow(unused)] // Not used
7474
pub guided_whitespace_pattern: Option<String>,
75+
76+
/// Whether to skip special tokens in the decoded output.
77+
/// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
78+
/// When false, special tokens are included in the output text.
79+
/// Defaults to false if not specified.
80+
#[serde(default, skip_serializing_if = "Option::is_none")]
81+
#[builder(default, setter(strip_option))]
82+
pub skip_special_tokens: Option<bool>,
7583
}
7684

7785
impl CommonExt {
@@ -99,6 +107,9 @@ pub trait CommonExtProvider {
99107
fn get_min_p(&self) -> Option<f32>;
100108
fn get_repetition_penalty(&self) -> Option<f32>;
101109
fn get_include_stop_str_in_output(&self) -> Option<bool>;
110+
111+
/// Output Options
112+
fn get_skip_special_tokens(&self) -> Option<bool>;
102113
}
103114

104115
#[cfg(test)]
@@ -120,6 +131,7 @@ mod tests {
120131
assert_eq!(common_ext.guided_choice, None);
121132
assert_eq!(common_ext.guided_decoding_backend, None);
122133
assert_eq!(common_ext.include_stop_str_in_output, None);
134+
assert_eq!(common_ext.skip_special_tokens, None);
123135
}
124136

125137
#[test]
@@ -135,6 +147,7 @@ mod tests {
135147
.guided_grammar("grammar".to_string())
136148
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
137149
.guided_decoding_backend("backend".to_string())
150+
.skip_special_tokens(false)
138151
.build()
139152
.unwrap();
140153

@@ -157,6 +170,7 @@ mod tests {
157170
common_ext.guided_decoding_backend,
158171
Some("backend".to_string())
159172
);
173+
assert_eq!(common_ext.skip_special_tokens, Some(false));
160174
}
161175

162176
#[test]
@@ -190,6 +204,7 @@ mod tests {
190204
guided_choice: None,
191205
guided_decoding_backend: None,
192206
guided_whitespace_pattern: None,
207+
skip_special_tokens: None,
193208
};
194209
assert!(common_ext.validate().is_ok());
195210
}
@@ -219,4 +234,52 @@ mod tests {
219234
assert_eq!(common_ext.include_stop_str_in_output, None);
220235
assert!(common_ext.validate().is_ok());
221236
}
237+
238+
#[test]
239+
fn test_skip_special_tokens_field() {
240+
// Test that skip_special_tokens can be set and retrieved
241+
let common_ext = CommonExt::builder()
242+
.skip_special_tokens(true)
243+
.build()
244+
.unwrap();
245+
246+
assert_eq!(common_ext.skip_special_tokens, Some(true));
247+
248+
let common_ext = CommonExt::builder()
249+
.skip_special_tokens(false)
250+
.build()
251+
.unwrap();
252+
253+
assert_eq!(common_ext.skip_special_tokens, Some(false));
254+
}
255+
256+
#[test]
257+
fn test_skip_special_tokens_serialization() {
258+
// Test that skip_special_tokens can be serialized and deserialized
259+
let common_ext = CommonExt::builder()
260+
.skip_special_tokens(true)
261+
.build()
262+
.unwrap();
263+
264+
let json = serde_json::to_string(&common_ext).unwrap();
265+
let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
266+
267+
assert_eq!(deserialized.skip_special_tokens, Some(true));
268+
269+
// Test with false value
270+
let common_ext = CommonExt::builder()
271+
.skip_special_tokens(false)
272+
.build()
273+
.unwrap();
274+
275+
let json = serde_json::to_string(&common_ext).unwrap();
276+
let deserialized: CommonExt = serde_json::from_str(&json).unwrap();
277+
278+
assert_eq!(deserialized.skip_special_tokens, Some(false));
279+
280+
// Test that None is not serialized (skip_serializing_if = "Option::is_none")
281+
let common_ext = CommonExt::builder().build().unwrap();
282+
let json = serde_json::to_string(&common_ext).unwrap();
283+
assert!(!json.contains("skip_special_tokens"));
284+
}
222285
}

lib/llm/src/protocols/openai/completions.rs

Lines changed: 51 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
222222
fn get_include_stop_str_in_output(&self) -> Option<bool> {
223223
self.common.include_stop_str_in_output
224224
}
225+
226+
fn get_skip_special_tokens(&self) -> Option<bool> {
227+
self.common.skip_special_tokens
228+
}
225229
}
226230

227231
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
@@ -397,7 +401,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
397401
}
398402

399403
fn get_skip_special_tokens(&self) -> Option<bool> {
400-
None
404+
CommonExtProvider::get_skip_special_tokens(self)
401405
}
402406

403407
fn get_formatted_prompt(&self) -> Option<bool> {
@@ -444,3 +448,49 @@ impl ValidateRequest for NvCreateCompletionRequest {
444448
Ok(())
445449
}
446450
}
451+
452+
#[cfg(test)]
453+
mod tests {
454+
use super::*;
455+
use crate::protocols::common::OutputOptionsProvider;
456+
use serde_json::json;
457+
458+
#[test]
459+
fn test_skip_special_tokens_none() {
460+
let json_str = json!({
461+
"model": "test-model",
462+
"prompt": "Hello, world!"
463+
});
464+
465+
let request: NvCreateCompletionRequest =
466+
serde_json::from_value(json_str).expect("Failed to deserialize request");
467+
468+
assert_eq!(request.common.skip_special_tokens, None);
469+
470+
let output_options = request
471+
.extract_output_options()
472+
.expect("Failed to extract output options");
473+
474+
assert_eq!(output_options.skip_special_tokens, None);
475+
}
476+
477+
#[test]
478+
fn test_skip_special_tokens_propagates() {
479+
for skip_value in [true, false] {
480+
let json_str = json!({
481+
"model": "test-model",
482+
"prompt": "Hello, world!",
483+
"skip_special_tokens": skip_value
484+
});
485+
486+
let request: NvCreateCompletionRequest =
487+
serde_json::from_value(json_str).expect("Failed to deserialize request");
488+
489+
let output_options = request
490+
.extract_output_options()
491+
.expect("Failed to extract output options");
492+
493+
assert_eq!(output_options.skip_special_tokens, Some(skip_value));
494+
}
495+
}
496+
}

lib/llm/tests/tokenizers.rs

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,3 +178,31 @@ fn test_long_sequence_incremental_decode_with_prefill() {
178178
assert_eq!(output.trim(), output_text.to_string());
179179
}
180180
}
181+
182+
#[test]
183+
fn test_decode_with_skip_special_tokens() {
184+
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
185+
.expect("Failed to load remote HuggingFace tokenizer");
186+
187+
// Create a sequence with special tokens:
188+
// <s> (token_id: 1) + "Hello world" + </s> (token_id: 2)
189+
let text = "Hello world";
190+
let encoding = tokenizer.encode(text).expect("Failed to encode text");
191+
let mut token_ids = vec![1]; // <s>
192+
token_ids.extend(encoding.token_ids());
193+
token_ids.push(2); // </s>
194+
195+
// Decode with skip_special_tokens = false (should keep special tokens)
196+
let decoded_with_special = tokenizer
197+
.decode(&token_ids, false)
198+
.expect("Failed to decode with skip_special_tokens=false");
199+
200+
// Decode with skip_special_tokens = true (should remove special tokens)
201+
let decoded_without_special = tokenizer
202+
.decode(&token_ids, true)
203+
.expect("Failed to decode with skip_special_tokens=true");
204+
205+
// Validate exact matches on the entire decoded strings
206+
assert_eq!(decoded_with_special, "<s> Hello world</s>");
207+
assert_eq!(decoded_without_special, "Hello world");
208+
}

0 commit comments

Comments (0)