
Commit ae6837f

Removed tokenizer/processor creation from example scripts (#4211)

1 parent: 56a8f11

File tree: 11 files changed (+13, -48 lines)
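What ties these diffs together: every TRL trainer touched here falls back to loading its own tokenizer or processor from the model checkpoint when no `processing_class` is passed, so the example scripts can stop constructing one by hand. A minimal sketch of the resulting calling pattern, assuming that fallback behavior; the model name, dataset, and output directory are illustrative, not taken from this commit:

    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer

    dataset = load_dataset("trl-lib/Capybara", split="train")  # illustrative dataset
    trainer = SFTTrainer(
        model="Qwen/Qwen2.5-0.5B",  # no AutoTokenizer.from_pretrained(...) needed
        args=SFTConfig(output_dir="sft-out"),
        train_dataset=dataset,  # the trainer resolves the tokenizer itself
    )
    trainer.train()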


examples/scripts/dpo_vlm.py

Lines changed: 1 addition & 5 deletions

@@ -85,7 +85,7 @@
     script_args, training_args, model_args = parser.parse_args_and_config()

     ################
-    # Model & Tokenizer
+    # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

@@ -117,7 +117,6 @@
     processor = AutoProcessor.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, do_image_splitting=False
     )
-    tokenizer = processor.tokenizer

     # Set up the chat template
     if model.config.model_type == "idefics2":
@@ -127,8 +126,6 @@
     elif model.config.model_type == "llava":
         processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
     if script_args.ignore_bias_buffers:
         # torch distributed hack
         model._ddp_params_and_buffers_to_ignore = [
@@ -153,7 +150,6 @@
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        processing_class=processor,
         peft_config=peft_config,
     )
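Note that dpo_vlm.py keeps its AutoProcessor instance, since the chat-template patching above still needs it; only the `tokenizer = processor.tokenizer` alias, the pad-token fixup, and the explicit `processing_class=processor` hand-off go away. A sketch of the fallback the trainer presumably applies when `processing_class` is omitted; this illustrates the assumed behavior, it is not the actual TRL source:

    from transformers import AutoProcessor

    def resolve_processing_class(model, processing_class=None):
        # hypothetical fallback, mirroring what the trainer presumably does
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
        return processing_class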

examples/scripts/grpo_vlm.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@
     parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
     script_args, training_args, model_args = parser.parse_args_and_config()
     ################
-    # Model & Processor
+    # Model
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
     training_args.model_init_kwargs = dict(
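grpo_vlm.py (and gspo_vlm.py and rloo_vlm.py below) change only the section header: these scripts already pass the model as a name string together with `training_args.model_init_kwargs`, so the trainer builds both the model and its processor internally, and the old "Model & Processor" comment oversold what the script itself does. A hedged sketch of that string-based pattern; the reward function, model, and dataset are illustrative:

    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    def reward_len(completions, **kwargs):
        # illustrative reward: prefer shorter completions
        return [-float(len(c)) for c in completions]

    dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
    training_args = GRPOConfig(output_dir="grpo-out")
    training_args.model_init_kwargs = {"dtype": "auto"}  # forwarded to from_pretrained
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # a string; the trainer instantiates model and tokenizer
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )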

examples/scripts/gspo_vlm.py

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@
     parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
     script_args, training_args, model_args = parser.parse_args_and_config()
     ################
-    # Model & Processor
+    # Model
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
     training_args.model_init_kwargs = dict(

examples/scripts/mpo_vlm.py

Lines changed: 1 addition & 5 deletions

@@ -46,7 +46,7 @@
 import torch
 from datasets import load_dataset
 from PIL import Image
-from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers import AutoModelForImageTextToText

 from trl import (
     DPOConfig,
@@ -97,9 +97,6 @@
         )
     else:
         ref_model = None
-    processor = AutoProcessor.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )

     ################
     # Dataset
@@ -135,7 +132,6 @@ def ensure_rgb(example):
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=test_dataset,
-        processing_class=processor,
         peft_config=peft_config,
     )
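mpo_vlm.py goes a step further than dpo_vlm.py: nothing else in this script uses the processor, so the AutoProcessor import and instantiation disappear entirely rather than just the `processing_class` argument.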

examples/scripts/reward_modeling.py

Lines changed: 1 addition & 11 deletions

@@ -57,7 +57,7 @@
 import torch
 from accelerate import logging
 from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
+from transformers import AutoModelForSequenceClassification, HfArgumentParser

 from trl import (
     ModelConfig,
@@ -97,18 +97,9 @@
         model_kwargs["device_map"] = get_kbit_device_map()
         model_kwargs["quantization_config"] = quantization_config

-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
-    )
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
-    # Align padding tokens between tokenizer and model
-    model.config.pad_token_id = tokenizer.pad_token_id
-
-    # If post-training a base model, use ChatML as the default template
-    if tokenizer.chat_template is None:
-        model, tokenizer = setup_chat_format(model, tokenizer)

     if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS":
         logger.warning(
@@ -126,7 +117,6 @@
     ##########
     trainer = RewardTrainer(
         model=model,
-        processing_class=tokenizer,
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
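This is the deepest cut in the commit: tokenizer creation, pad-token alignment, and the ChatML fallback via `setup_chat_format` all vanish, implying RewardTrainer now performs the equivalent setup internally. A minimal sketch of the resulting user-facing pattern, assuming the trainer resolves the tokenizer (including a usable pad token) from the model checkpoint; model and dataset names are illustrative:

    from datasets import load_dataset
    from transformers import AutoModelForSequenceClassification
    from trl import RewardConfig, RewardTrainer

    model = AutoModelForSequenceClassification.from_pretrained(
        "Qwen/Qwen2.5-0.5B", num_labels=1  # reward models output a single scalar
    )
    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
    trainer = RewardTrainer(
        model=model,
        args=RewardConfig(output_dir="reward-out"),
        train_dataset=dataset,  # no processing_class: tokenizer comes from the checkpoint
    )
    trainer.train()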

examples/scripts/rloo_vlm.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@
     parser = TrlParser((ScriptArguments, RLOOConfig, ModelConfig))
     script_args, training_args, model_args = parser.parse_args_and_config()
     ################
-    # Model & Processor
+    # Model
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
     training_args.model_init_kwargs = dict(

examples/scripts/sft_gpt_oss.py

Lines changed: 2 additions & 4 deletions

@@ -52,7 +52,7 @@
 import os

 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
+from transformers import AutoModelForCausalLM, Mxfp4Config

 from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_peft_config

@@ -62,7 +62,7 @@


 def main(script_args, training_args, model_args):
-    # Load model & tokenizer
+    # Load model
     quantization_config = Mxfp4Config(dequantize=True)
     model_kwargs = dict(
         revision=model_args.model_revision,
@@ -75,7 +75,6 @@ def main(script_args, training_args, model_args):

     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

-    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

     # Load dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
@@ -86,7 +85,6 @@ def main(script_args, training_args, model_args):
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        processing_class=tokenizer,
         peft_config=get_peft_config(model_args),
     )
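Here both calling conventions meet: the script still loads the model object itself (it needs Mxfp4Config dequantization), but hands SFTTrainer no tokenizer at all. A condensed sketch of the remaining flow; the checkpoint and dataset names are illustrative stand-ins for the script's CLI arguments:

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, Mxfp4Config
    from trl import SFTConfig, SFTTrainer

    model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",  # illustrative; the script reads model_args.model_name_or_path
        quantization_config=Mxfp4Config(dequantize=True),
    )
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")  # illustrative
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(output_dir="sft-gpt-oss"),
        train_dataset=dataset,  # tokenizer resolved from the model checkpoint
    )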

examples/scripts/sft_video_llm.py

Lines changed: 1 addition & 8 deletions

@@ -62,7 +62,7 @@
 from datasets import load_dataset
 from peft import LoraConfig
 from qwen_vl_utils import process_vision_info
-from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor
+from transformers import AutoModelForImageTextToText, BitsAndBytesConfig, Qwen2VLProcessor

 from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map

@@ -224,10 +224,6 @@ class CustomScriptArguments(ScriptArguments):
     model.config.use_reentrant = False
     model.enable_input_require_grads()

-    processor = AutoProcessor.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-
     # Prepare dataset
     prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset]

@@ -238,7 +234,6 @@ class CustomScriptArguments(ScriptArguments):
         train_dataset=prepared_dataset,
         data_collator=collate_fn,
         peft_config=peft_config,
-        processing_class=processor,
     )

     # Train model
@@ -248,8 +243,6 @@ class CustomScriptArguments(ScriptArguments):
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
         trainer.push_to_hub(dataset_name=script_args.dataset_name)
-        if trainer.accelerator.is_main_process:
-            processor.push_to_hub(training_args.hub_model_id)

     # Cleanup
     del model
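The last hunk is worth a second look: dropping the standalone `processor.push_to_hub(...)` call suggests `trainer.push_to_hub(...)` now uploads the internally resolved processing class along with the model, which also removes the need for the manual `trainer.accelerator.is_main_process` guard. The retained `Qwen2VLProcessor` import indicates the collator still does processor-specific handling; only the trainer hand-off and the separate Hub upload were redundant.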

examples/scripts/sft_vlm.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@
     training_args.max_length = None

     ################
-    # Model, Tokenizer & Processor
+    # Model
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
     model_kwargs = dict(

examples/scripts/sft_vlm_gemma3.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ def main():
     training_args.max_length = None

     ################
-    # Model, Tokenizer & Processor
+    # Model
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
     model_kwargs = dict(
