
Commit 016c85d

fix: improve tokenization (#1113)
1 parent 8f579d6 commit 016c85d

5 files changed: 73 additions, 25 deletions


examples/multi_agent/agent_system.py

Lines changed: 15 additions & 5 deletions
@@ -20,11 +20,21 @@ async def generate_response(args, prompt, key):
 
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 
-    prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+        sample.prompt = prompt_text
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        sample.prompt = prompt
+    prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
     sample.tokens = prompt_token_ids
-    sample.prompt = prompt
-    input_token_ids = prompt_token_ids
-    prompt_length = len(input_token_ids)
+    prompt_length = len(prompt_token_ids)
     current_sampling_params = deepcopy(sampling_params)
     current_sampling_params["max_new_tokens"] = min(
         sampling_params["max_new_tokens"], max_context_length - prompt_length
@@ -33,7 +43,7 @@ async def generate_response(args, prompt, key):
     if current_sampling_params["max_new_tokens"] <= 0:
         return None
 
-    payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
+    payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
 
     output = await post(url, payload)
 
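
The branch distinguishes two prompt formats. A minimal standalone sketch of the same behaviour with a plain Hugging Face tokenizer (not part of the commit; the model name is an arbitrary assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model

# apply_chat_template=True: the prompt is a list of chat messages, rendered to text
# by the chat template first and only then tokenized.
chat_prompt = [{"role": "user", "content": "What is 2 + 2?"}]
text = tokenizer.apply_chat_template(chat_prompt, tokenize=False, add_generation_prompt=True)
chat_ids = tokenizer(text, add_special_tokens=False)["input_ids"]

# apply_chat_template=False: the prompt is already a plain string and is tokenized as-is.
raw_prompt = "What is 2 + 2?"
raw_ids = tokenizer(raw_prompt, add_special_tokens=False)["input_ids"]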

examples/search-r1/generate_with_search.py

Lines changed: 13 additions & 2 deletions
@@ -151,15 +151,26 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     # Handle partial rollout samples: continue generation from existing response
     prompt = sample.prompt
-    prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = state.tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        prompt_text = prompt
+    prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
     response = ""
     response_token_ids = []
     loss_mask = []
     rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
 
     for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
         payload = {
-            "text": prompt + response,
+            "text": prompt_text + response,
             "sampling_params": sampling_params,
         }
         # Add log probability collection if enabled
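
This rollout sends a text payload to SGLang rather than token ids, so the templated prompt must be a plain string before the turn loop starts, and each turn resends it together with everything generated so far. A rough standalone sketch of that accumulation (fake_generate is only a stand-in for the post() call to the generate endpoint; the prompt string is an assumed template output):

def fake_generate(text: str) -> str:
    # Placeholder for the SGLang generate call made by the real code.
    return " <search>example query</search>"

prompt_text = "<|user|>Who wrote Hamlet?<|assistant|>"  # assumed chat-template output
response = ""
for _turn in range(2):  # stands in for SEARCH_R1_CONFIGS["max_turns"]
    payload_text = prompt_text + response  # prompt plus everything generated in prior turns
    response += fake_generate(payload_text)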

slime/rollout/sglang_rollout.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
         state.tokenizer,
         state.processor,
         sample.metadata,
+        args.apply_chat_template,
         args.apply_chat_template_kwargs,
     )
 
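
The extra argument matches the updated prepare_model_inputs signature shown below in slime/utils/processing_utils.py. A hedged usage sketch, assuming slime is installed and using an arbitrary model name (the local variable names are illustrative, not the rollout code):

from transformers import AutoTokenizer
from slime.utils.processing_utils import prepare_model_inputs

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model

messages = [{"role": "user", "content": "hello"}]
input_ids, extra_info = prepare_model_inputs(
    messages,
    tokenizer,
    processor=None,
    metadata=None,
    apply_chat_template=True,  # the flag this commit threads through from args
    apply_chat_template_kwargs=None,
)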

slime/utils/data.py

Lines changed: 22 additions & 10 deletions
@@ -49,23 +49,32 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+def _should_skip_prompt(
+    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+):
     if max_length is None:
         return False
 
     from slime.utils.processing_utils import prepare_model_inputs
 
-    input_ids, _ = prepare_model_inputs(prompt, tokenizer, processor, None, apply_chat_template_kwargs)
+    input_ids, _ = prepare_model_inputs(
+        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
+    )
     return len(input_ids) > max_length
 
 
-def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
-    messages = data.get(prompt_key)
+def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None):
+    prompt = data.get(prompt_key)
 
-    if isinstance(messages, str):
-        messages = [{"role": "user", "content": messages}]
+    if isinstance(prompt, str):
+        # If prompt is a string and we don't apply chat template, return the prompt as is.
+        if not as_conversation:
+            return prompt
+        else:
+            prompt = [{"role": "user", "content": prompt}]
 
     if multimodal_keys:
+        assert as_conversation, "as_conversation must be True when multimodal_keys is not None"
         # Build mapping: placeholder -> (MultimodalType, content_list)
         multimodals = {}
         for type_name, data_key in multimodal_keys.items():
@@ -75,7 +84,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
 
         pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
 
-        for message in messages:
+        for message in prompt:
             if isinstance(message["content"], str):
                 content_list = []
                 for segment in re.split(pattern, message["content"]):
@@ -105,7 +114,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
                     f"Unsupported content type: {type(message['content'])}, expected str or list of dicts"
                 )
 
-    return messages
+    return prompt
 
 
 class Dataset:
@@ -127,7 +136,8 @@ def __init__(
     ):
         self.origin_samples = []
         for data in read_file(path):
-            prompt = _build_messages(data, prompt_key, multimodal_keys)
+            as_conversation = apply_chat_template
+            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
             if tool_key is not None and tool_key in data:
@@ -140,7 +150,9 @@ def __init__(
                 metadata["tools"] = tools
 
             # TODO: this is slow.
-            if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+            if _should_skip_prompt(
+                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+            ):
                 continue
 
             self.origin_samples.append(
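
The behavioural change in _build_messages is that a string prompt is only wrapped into a chat message when the dataset is used with apply_chat_template; otherwise the raw string passes through unchanged. A standalone illustration of that branching (my own sketch, not the library code):

def build_prompt(raw, as_conversation: bool):
    # Mirrors the string-handling branch added to _build_messages above.
    if isinstance(raw, str):
        if not as_conversation:
            return raw  # plain-text path: keep the string as-is
        return [{"role": "user", "content": raw}]  # chat path: wrap as a single user turn
    return raw  # already a list of messages

assert build_prompt("2 + 2 = ?", as_conversation=False) == "2 + 2 = ?"
assert build_prompt("2 + 2 = ?", as_conversation=True) == [{"role": "user", "content": "2 + 2 = ?"}]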

slime/utils/processing_utils.py

Lines changed: 22 additions & 8 deletions
@@ -2,6 +2,7 @@
 import io
 import logging
 
+import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,9 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None):
+def prepare_model_inputs(
+    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
+):
     """Prepare all inputs for model inference.
 
     Returns:
@@ -34,13 +37,24 @@ def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply
         - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
     """
     tools = metadata.get("tools") if metadata else None
-    text_prompt = tokenizer.apply_chat_template(
-        prompt,
-        tools=tools,
-        tokenize=False,
-        add_generation_prompt=True,
-        **(apply_chat_template_kwargs or {}),
-    )
+    if isinstance(prompt, (list, np.ndarray)):
+        assert (
+            apply_chat_template
+        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
+        text_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    elif isinstance(prompt, str):
+        assert (
+            not apply_chat_template
+        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
+        text_prompt = prompt
+    else:
+        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
 
     if not processor:
         input_ids = tokenizer.encode(text_prompt, add_special_tokens=False)
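
prepare_model_inputs now dispatches on the prompt's type: message lists (and numpy object arrays, which is presumably how list-valued columns arrive from array-backed datasets) must go through the chat template, while plain strings must not. A small standalone check mirroring that isinstance dispatch (my own sketch, not the library code):

import numpy as np

def is_conversation(prompt) -> bool:
    # Same type test used above to decide whether the chat template applies.
    return isinstance(prompt, (list, np.ndarray))

assert is_conversation([{"role": "user", "content": "hi"}])
assert is_conversation(np.array([{"role": "user", "content": "hi"}], dtype=object))
assert not is_conversation("hi")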
