Commit 2e40a48

Revert data processing of VLM (#1232)

Reverts the shared prepare_model_inputs() helper: chat-template formatting moves back into dataset loading (slime/utils/data.py), the SGLang and SFT rollouts derive prompt token IDs directly from the tokenizer or processor, and the Ray rollout re-runs the HF processor per sample to rebuild multimodal features. The CI install steps also gain pip's --no-deps flag.

1 parent eaa6530

7 files changed: +62 −111 lines

.github/workflows/pr-test.yml

Lines changed: 4 additions & 4 deletions

@@ -63,7 +63,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -107,7 +107,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -151,7 +151,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash
@@ -195,7 +195,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash

.github/workflows/pr-test.yml.j2

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ jobs:
 
       - name: Install
         shell: bash
-        run: cd $GITHUB_WORKSPACE && pip install -e . --break-system-packages
+        run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
 
       - name: Execute
         shell: bash

slime/ray/rollout.py

Lines changed: 13 additions & 1 deletion

@@ -20,6 +20,7 @@
 from slime.utils.metric_checker import MetricChecker
 from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
 from slime.utils.misc import load_function
+from slime.utils.processing_utils import load_processor
 from slime.utils.ray_utils import Box
 from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
 from slime.utils.tracking_utils import init_tracking
@@ -77,6 +78,7 @@ def __init__(self, args, pg):
         self._metric_checker = MetricChecker.maybe_create(args)
         if self.args.use_fault_tolerance:
             self._health_monitor = RolloutHealthMonitor(self, args)
+        self.processor = None
 
     def dispose(self):
         if self._metric_checker is not None:
@@ -275,7 +277,17 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
         train_data["metadata"] = [sample.train_metadata for sample in samples]
 
         if samples[0].multimodal_inputs is not None:
-            train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples]
+            if self.processor is None:
+                self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
+            train_data["multimodal_inputs"] = []
+            for sample in samples:
+                # Get input IDs with full prompt (text + multimodal)
+                processor_output = self.processor(text=sample.prompt, **sample.multimodal_inputs)
+
+                # Extract multimodal tokens (exclude text-related tokens)
+                train_data["multimodal_inputs"].append(
+                    {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
+                )
 
         if "teacher_log_probs" in samples[0].__dict__:
             train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]

slime/rollout/sft_rollout.py

Lines changed: 2 additions & 13 deletions

@@ -1,7 +1,7 @@
 import logging
 
 from slime.utils.mask_utils import MultiTurnLossMaskGenerator
-from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs
+from slime.utils.processing_utils import load_processor, load_tokenizer
 
 __all__ = ["generate_rollout"]
 
@@ -46,18 +46,7 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
         messages = sample.prompt
         tools = sample.metadata.get("tools", None)
 
-        input_ids, extra_info = prepare_model_inputs(
-            messages, TOKENIZER, PROCESSOR, sample.metadata, args.apply_chat_template, args.apply_chat_template_kwargs
-        )
-
-        has_multimodal = bool(extra_info.get("images") or extra_info.get("videos"))
-        if has_multimodal:
-            sample.multimodal_inputs = extra_info["multimodal_inputs"]
-            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment(
-                messages, input_ids, tools=tools
-            )
-        else:
-            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
+        token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
 
         response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
 
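
After the revert, the SFT rollout takes only the text path: MASK_GENERATOR.get_loss_mask tokenizes the conversation and returns a per-token mask that selects the tokens trained on (assistant turns) and zeroes out the rest. MultiTurnLossMaskGenerator's internals are not shown in this diff, so the following is only a sketch of the general pattern, under the assumption of a role-based mask and an illustrative checkpoint:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative

    def get_loss_mask(messages):
        # Tokenize each turn and mark assistant tokens with 1, everything else with 0.
        token_ids, loss_mask = [], []
        for msg in messages:
            ids = tokenizer.apply_chat_template([msg], tokenize=True)
            flag = 1 if msg["role"] == "assistant" else 0
            token_ids.extend(ids)
            loss_mask.extend([flag] * len(ids))
        return token_ids, loss_mask

    messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
    token_ids, loss_mask = get_loss_mask(messages)
    response_length = sum(loss_mask)  # count of trainable (assistant) tokens

A production implementation would tokenize incrementally rather than per turn, to avoid the chat template re-inserting system-prompt tokens for each message.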

slime/rollout/sglang_rollout.py

Lines changed: 11 additions & 24 deletions

@@ -20,12 +20,7 @@
 from slime.utils.http_utils import get, post
 from slime.utils.mask_utils import get_response_lengths
 from slime.utils.misc import SingletonMeta, load_function
-from slime.utils.processing_utils import (
-    encode_image_for_rollout_engine,
-    load_processor,
-    load_tokenizer,
-    prepare_model_inputs,
-)
+from slime.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer
 from slime.utils.types import Sample
 
 from .rm_hub import async_rm, batched_async_rm
@@ -90,26 +85,21 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None:
 
 async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
     """Generate using traditional SGLang router with token-based workflow"""
+    if args.ci_test:
+        assert isinstance(sample.prompt, str)
+
     state = GenerateState(args)
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 
     assert (
         sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
     ), f"Sample status is {sample.status}"
 
-    prompt_ids, extra_info = prepare_model_inputs(
-        sample.prompt,
-        state.tokenizer,
-        state.processor,
-        sample.metadata,
-        args.apply_chat_template,
-        args.apply_chat_template_kwargs,
-    )
-
-    sample.prompt = extra_info.get("formatted_prompt", sample.prompt)
-    image_data = extra_info.get("images", [])
-    video_data = extra_info.get("videos", [])
-    multimodal_inputs = extra_info.get("multimodal_inputs", None)
+    if state.processor:
+        processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
+        prompt_ids = processor_output["input_ids"][0]
+    else:
+        prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False)
 
     if len(sample.response) > 0:
         sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
@@ -130,12 +120,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
     if args.use_rollout_routing_replay:
         payload["return_routed_experts"] = True
 
-    if image_data:
+    if sample.multimodal_inputs and sample.multimodal_inputs["images"]:
+        image_data = sample.multimodal_inputs["images"]
         payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data]
-        sample.multimodal_inputs = multimodal_inputs
-
-    if video_data:
-        raise NotImplementedError("Video data is not supported yet")
 
     # Use existing tokens for multi-turn or tokenize the new prompt
     if len(sample.response) > 0:
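
With prepare_model_inputs gone, prompt-id derivation becomes a two-way branch: a VLM run feeds the pre-formatted prompt string and the stored multimodal inputs through the processor, while a text-only run just encodes with the tokenizer. A self-contained sketch of that branch (the checkpoint name and prompt are illustrative; in slime these come from GenerateState and the Sample):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative
    processor = None  # a VLM run would hold an AutoProcessor here
    multimodal_inputs = None  # e.g. {"images": [...], "videos": [...]} for a VLM sample
    prompt = "<|im_start|>user\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"

    if processor is not None and multimodal_inputs is not None:
        # Processor path: input_ids come back batched, so take row 0.
        prompt_ids = processor(text=prompt, **multimodal_inputs)["input_ids"][0]
    else:
        # Tokenizer path: the chat template was already applied upstream,
        # so special tokens must not be added a second time.
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)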

slime/utils/data.py

Lines changed: 31 additions & 11 deletions

@@ -49,17 +49,16 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(
-    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
-):
+def _should_skip_prompt(formatted_prompt: str, tokenizer, processor, max_length, multimodal_inputs=None):
     if max_length is None:
         return False
 
-    from slime.utils.processing_utils import prepare_model_inputs
+    if processor:
+        processor_output = processor(text=formatted_prompt, **multimodal_inputs)
+        input_ids = processor_output["input_ids"][0]
+    else:
+        input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
 
-    input_ids, _ = prepare_model_inputs(
-        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
-    )
     return len(input_ids) > max_length
 
 
@@ -140,6 +139,7 @@ def __init__(
             prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
+            tools = None
             if tool_key is not None and tool_key in data:
                 tools = data[tool_key]
                 if isinstance(tools, str):
@@ -149,17 +149,37 @@ def __init__(
                 assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
                 metadata["tools"] = tools
 
+            if apply_chat_template:
+                formatted_prompt = tokenizer.apply_chat_template(
+                    prompt,
+                    tools=tools,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    **(apply_chat_template_kwargs or {}),
+                )
+            else:
+                formatted_prompt = prompt
+
+            if processor:
+                # temporary solution, will write image utils for slime later
+                from qwen_vl_utils import process_vision_info
+
+                assert isinstance(prompt, list)
+                images, videos = process_vision_info(prompt)
+                multimodal_inputs = {"images": images, "videos": videos}
+            else:
+                multimodal_inputs = None
+
             # TODO: this is slow.
-            if _should_skip_prompt(
-                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
-            ):
+            if _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs):
                 continue
 
             self.origin_samples.append(
                 Sample(
-                    prompt=prompt,
+                    prompt=formatted_prompt,
                     label=data[label_key] if label_key is not None else None,
                     metadata=metadata,
+                    multimodal_inputs=multimodal_inputs,
                 )
             )
 
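
Dataset loading now owns prompt formatting: the chat template (with tools) is applied once at load time, vision inputs are extracted with qwen_vl_utils, and the max-length filter measures the exact token ids the rollout will see. A runnable sketch of the format-then-filter flow for the text-only case (checkpoint, message, and max_length are illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative
    prompt = [{"role": "user", "content": "What is 2 + 2?"}]
    max_length = 1024  # illustrative cap

    # Format once at load time, exactly as the constructor above does.
    formatted_prompt = tokenizer.apply_chat_template(
        prompt, tokenize=False, add_generation_prompt=True
    )

    # Length filter: skip samples whose formatted prompt is too long.
    input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    skip = len(input_ids) > max_length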

slime/utils/processing_utils.py

Lines changed: 0 additions & 57 deletions

@@ -2,7 +2,6 @@
 import io
 import logging
 
-import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -26,62 +25,6 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(
-    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
-):
-    """Prepare all inputs for model inference.
-
-    Returns:
-        tuple: (input_ids, extra_info)
-        - input_ids: Token IDs for the prompt
-        - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
-    """
-    tools = metadata.get("tools") if metadata else None
-    if isinstance(prompt, (list, np.ndarray)):
-        assert (
-            apply_chat_template
-        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
-        formatted_prompt = tokenizer.apply_chat_template(
-            prompt,
-            tools=tools,
-            tokenize=False,
-            add_generation_prompt=True,
-            **(apply_chat_template_kwargs or {}),
-        )
-    elif isinstance(prompt, str):
-        assert (
-            not apply_chat_template
-        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
-        formatted_prompt = prompt
-    else:
-        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
-
-    if not processor:
-        input_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
-        return input_ids, {"formatted_prompt": formatted_prompt}
-    else:
-        # temporary solution, will write image utils for slime later
-        from qwen_vl_utils import process_vision_info
-
-        images, videos = process_vision_info(prompt)
-
-        # Get input IDs with full prompt (text + multimodal)
-        processor_output = processor(text=formatted_prompt, images=images, videos=videos)
-        input_ids = processor_output["input_ids"][0]
-
-        # Extract multimodal tokens (exclude text-related tokens)
-        multimodal_inputs = {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
-
-        extra_info = {
-            "formatted_prompt": formatted_prompt,
-            "images": images,
-            "videos": videos,
-            "multimodal_inputs": multimodal_inputs,
-        }
-
-        return input_ids, extra_info
-
-
 def encode_image_for_rollout_engine(image) -> str:
     """Load an image from path, ensure RGB, encode as JPEG base64 string."""
     buffer = io.BytesIO()
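
The body of encode_image_for_rollout_engine is truncated in this diff; only its docstring and first line survive as context. A sketch consistent with that docstring, assuming the function accepts PIL images or file paths (the real body may differ):

    import base64
    import io

    from PIL import Image

    def encode_image_for_rollout_engine(image) -> str:
        """Load an image from path, ensure RGB, encode as JPEG base64 string."""
        if isinstance(image, str):  # assumption: a file path is also accepted
            image = Image.open(image)
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")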
