Commit 453db5c
🤏 New models for tests (huggingface#2287)
* first commit
* uncomment
* other tests adaptations
* Remove unused variable in test_setup_chat_format
* Remove unused import statement
* style
* Add Bart model
* Update BCOTrainerTester class in test_bco_trainer.py
* Update model IDs and tokenizers in test files
* Add new models and processors
* Update model IDs in test files
* Fix formatting issue in test_dataset_formatting.py
* Refactor dataset formatting in test_dataset_formatting.py
* Fix dataset sequence length in SFTTrainerTester
* Remove tokenizer
* Remove print statement
* Add reward_model_path and sft_model_path to PPO trainer
* Fix tokenizer padding issue
* Add chat template for testing purposes in PaliGemma model
* Update PaliGemma model and chat template
* Increase learning rate to speed up test
* Update model names in run_dpo.sh and run_sft.sh scripts
* Update model and dataset names
* Fix formatting issue in test_dataset_formatting.py
* Fix formatting issue in test_dataset_formatting.py
* Remove unused chat template
* Update model generation script
* additional models
* Update model references in test files
* Remove unused imports in test_online_dpo_trainer.py
* Add is_llm_blender_available import and update reward_tokenizer
* Refactor test_online_dpo_trainer.py: Move skipped test case decorator
* remove models without chat templates
* Update model names in scripts and tests
* Update model_id in test_modeling_value_head.py
* Update model versions in test files
* Fix formatting issue in test_dataset_formatting.py
* Update embedding model ID in BCOTrainerTester
* Update test_online_dpo_trainer.py with reward model changes
* Update expected formatted text in test_dataset_formatting.py
* Add reward_tokenizer to TestOnlineDPOTrainer
* fix tests
* Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer
* Fix dummy_text format in test_rloo_trainer.py
* Skip outdated test for chatML data collator
* Add new vision language models
* Commented out unused model IDs in test_vdpo_trainer
* Update model and vision configurations in generate_tiny_models.py and test_dpo_trainer.py
* Update model and tokenizer references
* Don't push if it already exists
* Add comment explaining test skip
* Fix model_exists function call and add new models
* Update LlavaForConditionalGeneration model and processor
* `qgallouedec` -> `trl-internal-testing`
1 parent ee3cbe1 · commit 453db5c

32 files changed: +482 -275 lines
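The theme of the diff is a model swap: dummy checkpoints with ad-hoc names such as `trl-internal-testing/dummy-GPT2-correct-vocab` and `trl-internal-testing/tiny-random-LlamaForCausalLM` are replaced by tiny models published under a consistent `tiny-<ModelClass>[-<version>]` scheme in the `trl-internal-testing` organization. A minimal sketch of how a test would load one of the new checkpoints (the repo ID comes from this commit; the rest is standard transformers usage, nothing specific to TRL assumed):

```python
# Minimal sketch: loading one of the new tiny test checkpoints.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# The weights are randomly initialized, so outputs are meaningless;
# the model only needs to exercise training code paths quickly.
inputs = tokenizer("Hello", return_tensors="pt")
print(model(**inputs).logits.shape)
```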

commands/run_dpo.sh

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 # This script runs an SFT example end-to-end on a tiny model using different possible configurations
 # but defaults to QLoRA + PEFT
 OUTPUT_DIR="test_dpo/"
-MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
+MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
 DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
 MAX_STEPS=5
 BATCH_SIZE=2

commands/run_sft.sh

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 # This script runs an SFT example end-to-end on a tiny model using different possible configurations
 # but defaults to QLoRA + PEFT
 OUTPUT_DIR="test_sft/"
-MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
+MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
 DATASET_NAME="stanfordnlp/imdb"
 MAX_STEPS=5
 BATCH_SIZE=2

docs/source/clis.mdx

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ We also recommend you passing a YAML config file to configure your training prot
 
 ```yaml
 model_name_or_path:
-  trl-internal-testing/tiny-random-LlamaForCausalLM
+  Qwen/Qwen2.5-0.5B
 dataset_name:
   stanfordnlp/imdb
 report_to:

examples/cli_configs/example_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 # CUDA_VISIBLE_DEVICES: 0
 
 model_name_or_path:
-  trl-internal-testing/tiny-random-LlamaForCausalLM
+  Qwen/Qwen2.5-0.5B
 dataset_name:
   stanfordnlp/imdb
 report_to:
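Both YAML files above are consumed by the TRL command line. A hedged sketch of using one, mirroring the `subprocess.run` pattern from `tests/test_cli.py` further down; the `--config` flag is assumed from TRL's `TrlParser` convention and should be checked against your TRL version:

```python
# Hedged sketch: running the TRL CLI with a YAML config file.
# Assumption: the CLI accepts a --config flag (TrlParser convention).
import subprocess

subprocess.run(
    "trl sft --config examples/cli_configs/example_config.yaml",
    shell=True,
    check=True,
)
```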

scripts/generate_tiny_models.py

Lines changed: 193 additions & 0 deletions
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script generates tiny models used in the TRL library for unit tests. It pushes them to the Hub under the
# `trl-internal-testing` organization.
# This script is meant to be run when adding a new tiny model to the TRL library.

from huggingface_hub import HfApi, ModelCard
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    BartConfig,
    BartModel,
    BloomConfig,
    BloomForCausalLM,
    CLIPVisionConfig,
    CohereConfig,
    CohereForCausalLM,
    DbrxConfig,
    DbrxForCausalLM,
    FalconMambaConfig,
    FalconMambaForCausalLM,
    Gemma2Config,
    Gemma2ForCausalLM,
    GemmaConfig,
    GemmaForCausalLM,
    GPT2Config,
    GPT2LMHeadModel,
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    Idefics2Config,
    Idefics2ForConditionalGeneration,
    LlamaConfig,
    LlamaForCausalLM,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    MistralConfig,
    MistralForCausalLM,
    OPTConfig,
    OPTForCausalLM,
    PaliGemmaConfig,
    PaliGemmaForConditionalGeneration,
    Phi3Config,
    Phi3ForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    SiglipVisionConfig,
    T5Config,
    T5ForConditionalGeneration,
)
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig


ORGANIZATION = "trl-internal-testing"

MODEL_CARD = """
---
library_name: transformers
tags: [trl]
---

# Tiny {model_class_name}

This is a minimal model built for unit tests in the [TRL](https://github.com/huggingface/trl) library.
"""


api = HfApi()


def push_to_hub(model, tokenizer, suffix=None):
    model_class_name = model.__class__.__name__
    content = MODEL_CARD.format(model_class_name=model_class_name)
    model_card = ModelCard(content)
    repo_id = f"{ORGANIZATION}/tiny-{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"

    if api.repo_exists(repo_id):
        print(f"Model {repo_id} already exists, skipping")
    else:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
        model_card.push_to_hub(repo_id)


# Decoder models
for model_id, config_class, model_class, suffix in [
    ("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
    ("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
    ("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
    ("tiiuae/falcon-7b-instruct", FalconMambaConfig, FalconMambaForCausalLM, None),
    ("google/gemma-2-2b-it", Gemma2Config, Gemma2ForCausalLM, None),
    ("google/gemma-7b-it", GemmaConfig, GemmaForCausalLM, None),
    ("openai-community/gpt2", GPT2Config, GPT2LMHeadModel, None),
    ("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None),
    ("meta-llama/Meta-Llama-3-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3"),
    ("meta-llama/Llama-3.1-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3.1"),
    ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForCausalLM, "3.2"),
    ("mistralai/Mistral-7B-Instruct-v0.1", MistralConfig, MistralForCausalLM, "0.1"),
    ("mistralai/Mistral-7B-Instruct-v0.2", MistralConfig, MistralForCausalLM, "0.2"),
    ("facebook/opt-1.3b", OPTConfig, OPTForCausalLM, None),
    ("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None),
    ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"),
    ("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = config_class(
        vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
        hidden_size=8,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_hidden_layers=2,
        intermediate_size=32,
    )
    model = model_class(config)
    push_to_hub(model, tokenizer, suffix)


# Encoder-decoder models
for model_id, config_class, model_class, suffix in [
    ("google/flan-t5-small", T5Config, T5ForConditionalGeneration, None),
    ("facebook/bart-base", BartConfig, BartModel, None),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = config_class(
        vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
        d_model=16,
        encoder_layers=2,
        decoder_layers=2,
        d_kv=2,
        d_ff=64,
        num_layers=6,
        num_heads=8,
        decoder_start_token_id=0,
        is_encoder_decoder=True,
    )
    model = model_class(config)
    push_to_hub(model, tokenizer, suffix)


# Vision Language Models
# fmt: off
for model_id, config_class, text_config_class, vision_config_class, model_class in [
    ("HuggingFaceM4/idefics2-8b", Idefics2Config, MistralConfig, Idefics2VisionConfig, Idefics2ForConditionalGeneration),
    ("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlamaConfig, CLIPVisionConfig, LlavaForConditionalGeneration),
    ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, MistralConfig, CLIPVisionConfig, LlavaNextForConditionalGeneration),
    ("google/paligemma-3b-pt-224", PaliGemmaConfig, GemmaConfig, SiglipVisionConfig, PaliGemmaForConditionalGeneration),
]:
    # fmt: on
    processor = AutoProcessor.from_pretrained(model_id)
    kwargs = {}
    if config_class == PaliGemmaConfig:
        kwargs["projection_dim"] = 8
    vision_kwargs = {}
    if vision_config_class in [CLIPVisionConfig, SiglipVisionConfig]:
        vision_kwargs["projection_dim"] = 8
    if vision_config_class == CLIPVisionConfig:
        vision_kwargs["image_size"] = 336
        vision_kwargs["patch_size"] = 14
    config = config_class(
        text_config=text_config_class(
            vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
            hidden_size=8,
            num_attention_heads=4,
            num_key_value_heads=2,
            num_hidden_layers=2,
            intermediate_size=32,
        ),
        vision_config=vision_config_class(
            hidden_size=8,
            num_attention_heads=4,
            num_hidden_layers=2,
            intermediate_size=32,
            **vision_kwargs,
        ),
        **kwargs,
    )
    model = model_class(config)
    push_to_hub(model, processor)
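To get a feel for how small these checkpoints are, the decoder config above can be instantiated locally and its parameters counted; a sketch using the same hyperparameters as the script (the exact count depends on the tokenizer's vocabulary size, which dominates via the embedding matrix):

```python
# Sketch: rebuild the tiny decoder config from the script and count parameters.
from transformers import AutoTokenizer, Qwen2Config, Qwen2ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
config = Qwen2Config(
    vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
    hidden_size=8,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=2,
    intermediate_size=32,
)
model = Qwen2ForCausalLM(config)
# Almost all parameters sit in the (vocab_size x hidden_size) embedding;
# the two-layer transformer body itself is negligible.
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```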

tests/slow/testing_constants.py

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@
 
 # TODO: push them under trl-org
 MODELS_TO_TEST = [
-    "trl-internal-testing/tiny-random-LlamaForCausalLM",
-    "HuggingFaceM4/tiny-random-MistralForCausalLM",
+    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
+    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
 ]
 
 # We could have also not declared these variables but let's be verbose

tests/test_bco_trainer.py

Lines changed: 17 additions & 19 deletions
@@ -30,30 +30,30 @@
 
 class BCOTrainerTester(unittest.TestCase):
     def setUp(self):
-        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         # get t5 as seq2seq example:
-        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
+        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
         self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
 
         # get embedding model
-        model_id = "facebook/bart-base"
+        model_id = "trl-internal-testing/tiny-BartModel"
         self.embedding_model = AutoModel.from_pretrained(model_id)
         self.embedding_tokenizer = AutoTokenizer.from_pretrained(model_id)
 
     @parameterized.expand(
         [
-            ["gpt2", True, True, "standard_unpaired_preference"],
-            ["gpt2", True, False, "standard_unpaired_preference"],
-            ["gpt2", False, True, "standard_unpaired_preference"],
-            ["gpt2", False, False, "standard_unpaired_preference"],
-            ["gpt2", True, True, "conversational_unpaired_preference"],
+            ("qwen", True, True, "standard_unpaired_preference"),
+            ("qwen", True, False, "standard_unpaired_preference"),
+            ("qwen", False, True, "standard_unpaired_preference"),
+            ("qwen", False, False, "standard_unpaired_preference"),
+            ("qwen", True, True, "conversational_unpaired_preference"),
         ]
     )
     @require_sklearn
@@ -73,7 +73,7 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset, config_name):
 
         dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
 
-        if name == "gpt2":
+        if name == "qwen":
             model = self.model
             ref_model = self.ref_model
             tokenizer = self.tokenizer
@@ -160,9 +160,9 @@ def test_tokenize_and_process_tokens(self):
         self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
         self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
         self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
-        self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [5377, 11141])
-        self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1])
-        self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [318, 1365, 621, 8253, 13])
+        self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
+        self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
+        self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
         self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])
 
         fn_kwargs = {
@@ -178,15 +178,13 @@ def test_tokenize_and_process_tokens(self):
         self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
         self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
         self.assertListEqual(processed_dataset["label"], train_dataset["label"])
-        self.assertListEqual(processed_dataset["prompt_input_ids"][0], [50256, 5377, 11141])
-        self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1])
+        self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
+        self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
         self.assertListEqual(
-            processed_dataset["completion_input_ids"][0], [50256, 5377, 11141, 318, 1365, 621, 8253, 13, 50256]
-        )
-        self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1, 1, 1])
-        self.assertListEqual(
-            processed_dataset["completion_labels"][0], [-100, -100, -100, 318, 1365, 621, 8253, 13, 50256]
+            processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
         )
+        self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
+        self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])
 
     @require_sklearn
     def test_bco_trainer_without_providing_ref_model(self):
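The hard-coded IDs change because the fixture tokenizer is now Qwen2's instead of GPT-2's: the old trailing 50256 is GPT-2's `<|endoftext|>`, while 151645 is Qwen2's `<|im_end|>` end-of-turn token. A hedged sketch for regenerating such expectations after a tokenizer swap (the sample text is illustrative, not the actual `trl-internal-testing/zen` row):

```python
# Hedged sketch: regenerating hard-coded token-ID expectations after a
# tokenizer swap. The sample text is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
print(tokenizer.eos_token, tokenizer.eos_token_id)  # expected: <|im_end|> 151645
print(tokenizer("is better than", add_special_tokens=False)["input_ids"])
```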

tests/test_best_of_n_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class BestOfNSamplerTester(unittest.TestCase):
     Tests the BestOfNSampler class
     """
 
-    ref_model_name = "trl-internal-testing/dummy-GPT2-correct-vocab"
+    ref_model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
     output_length_sampler = LengthSampler(2, 6)
     model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
     tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tests/test_callbacks.py

Lines changed: 7 additions & 7 deletions
@@ -60,9 +60,9 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi
 
 class WinRateCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.tokenizer.pad_token = self.tokenizer.eos_token
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
@@ -219,8 +219,8 @@ def test_lora(self):
 @require_wandb
 class LogCompletionsCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.tokenizer.pad_token = self.tokenizer.eos_token
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
@@ -283,8 +283,8 @@ def test_basic(self):
 )
 class MergeModelCallbackTester(unittest.TestCase):
     def setUp(self):
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
 
     def test_callback(self):

tests/test_cli.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ class CLITester(unittest.TestCase):
     def test_sft_cli(self):
         try:
             subprocess.run(
-                "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
+                "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
                 shell=True,
                 check=True,
             )
@@ -32,7 +32,7 @@ def test_sft_cli(self):
     def test_dpo_cli(self):
         try:
             subprocess.run(
-                "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
+                "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
                 shell=True,
                 check=True,
             )
