
Commit 8265800

Fix trl-internal-testing/tiny-DbrxForCausalLM (#4213)

1 parent 65eb45c commit 8265800

File tree

2 files changed: +96 -1 lines changed

  scripts/generate_tiny_models.py
  tests/test_modeling_value_head.py

scripts/generate_tiny_models.py

Lines changed: 11 additions & 1 deletion

@@ -155,7 +155,6 @@ def init_weights_tiny_model(model):
 for model_id, config_class, model_class, suffix in [
     ("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
     ("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
-    ("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
     ("deepseek-ai/DeepSeek-R1", DeepseekV3Config, DeepseekV3ForCausalLM, None),
     # It's important to have R1-0528 as it doesn't have the same chat template
     ("deepseek-ai/DeepSeek-R1-0528", DeepseekV3Config, DeepseekV3ForCausalLM, "0528"),
@@ -209,6 +208,17 @@ def init_weights_tiny_model(model):
     init_weights_tiny_model(model)
     push_to_hub(model, tokenizer, "tiny", suffix)
 
+# Special case for databricks/dbrx-instruct as it requires specific changes in the config
+model_id = "databricks/dbrx-instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+config = DbrxConfig.from_pretrained(model_id, n_layers=2, n_heads=16, d_model=24)
+# transformers mistakenly ignores ffn_config keys when loading from pretrained. We need to set them manually after
+# loading the config
+config.ffn_config.ffn_hidden_size = 24
+config.ffn_config.hidden_size = 24
+model = DbrxForCausalLM(config).to(dtype=torch.bfloat16)
+init_weights_tiny_model(model)
+push_to_hub(model, tokenizer, "tiny")
 
 # Two slightly bigger models, required for vLLM testing
 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
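For context, here is a minimal usage sketch (not part of the commit) of the regenerated tiny checkpoint once the script above has pushed it; it assumes transformers >= 4.58.0.dev0, since DbrxConfig files generated after 4.58.0 can't be read by earlier Dbrx modeling code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical usage of the tiny checkpoint pushed by the script above.
model_id = "trl-internal-testing/tiny-DbrxForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(output_ids[0]))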

tests/test_modeling_value_head.py

Lines changed: 85 additions & 0 deletions

@@ -16,6 +16,8 @@
 
 import pytest
 import torch
+import transformers
+from packaging import version
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig
 
@@ -63,6 +65,12 @@ def test_value_head(self):
         Test if the v-head is added to the model successfully
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert hasattr(model, "v_head")
 
@@ -71,6 +79,12 @@ def test_value_head_shape(self):
         Test if the v-head has the correct shape
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert model.v_head.summary.weight.shape[0] == 1
 
@@ -80,6 +94,12 @@ def test_value_head_init_random(self):
         than zeros by default.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
 
@@ -89,6 +109,12 @@ def test_value_head_not_str(self):
         `from_pretrained`.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             pretrained_model = self.transformers_model_class.from_pretrained(model_name)
             model = self.trl_model_class.from_pretrained(pretrained_model)
             assert hasattr(model, "v_head")
@@ -99,6 +125,12 @@ def test_from_save_trl(self):
         additional modules (e.g. v_head)
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
 
             model.save_pretrained(self.tmp_dir)
@@ -114,6 +146,12 @@ def test_from_save_trl_sharded(self):
         Test if the model can be saved and loaded from a directory and get the same weights - sharded case
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
 
             model.save_pretrained(self.tmp_dir)
@@ -129,6 +167,12 @@ def test_from_save_transformers_sharded(self):
         Test if the model can be saved and loaded using transformers and get the same weights - sharded case
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
 
             trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -150,6 +194,12 @@ def test_from_save_transformers(self):
         of the super class to check if the weights are the same.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
 
             trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -200,6 +250,12 @@ def test_inference(self):
         EXPECTED_OUTPUT_SIZE = 3
 
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name).to(self.device)
             input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
             outputs = model(input_ids)
@@ -213,6 +269,12 @@ def test_dropout_config(self):
         Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             pretrained_model = self.transformers_model_class.from_pretrained(model_name)
             pretrained_model.config.summary_dropout_prob = 0.5
             model = self.trl_model_class.from_pretrained(pretrained_model)
@@ -225,6 +287,11 @@ def test_dropout_kwargs(self):
         Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
             v_head_kwargs = {"summary_dropout_prob": 0.5}
 
             model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
@@ -242,6 +309,12 @@ def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
+        if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+            transformers.__version__
+        ) < version.parse("4.58.0.dev0"):
+            # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+            pytest.xfail("DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version")
+
         generation_config = GenerationConfig(max_new_tokens=9)
         model = self.trl_model_class.from_pretrained(model_name).to(self.device)
         input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
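Unlike the looped tests, test_generate is parameterized, so each model name runs as its own test case and there is no loop to continue out of; the guard therefore marks the single case with pytest.xfail. A minimal sketch of that pattern, using pytest's built-in parametrize and hypothetical model names:

import pytest

@pytest.mark.parametrize("model_name", ["model-a", "model-b"])  # one test case per model
def test_generate_sketch(model_name):
    if model_name == "model-b":  # hypothetical incompatible model
        # Marks only this case as an expected failure instead of skipping a loop iteration
        pytest.xfail("model-b isn't compatible with this transformers version")
    assert isinstance(model_name, str)  # stand-in for the real generate check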
@@ -256,6 +329,12 @@ def test_transformers_bf16_kwargs(self):
         run a dummy forward pass without any issue.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device)
 
             lm_head_namings = ["lm_head", "embed_out", "output_layer"]
@@ -276,6 +355,12 @@ def test_transformers_bf16_kwargs(self):
     @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.")
     def test_push_to_hub(self):
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
             if "sharded" in model_name:
                 model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB")
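The same version guard is repeated in every loop above. A hypothetical refactoring (not part of this commit) could factor it into a helper so each loop body reduces to `if incompatible_dbrx(model_name): continue`:

import transformers
from packaging import version

# Hypothetical helper, not in the commit: the repeated guard factored out.
MIN_DBRX_TRANSFORMERS = version.parse("4.58.0.dev0")

def incompatible_dbrx(model_name: str) -> bool:
    # DbrxConfig generated after 4.58.0 isn't compatible with modeling code
    # before that version, so the tiny checkpoint can't be loaded there.
    return (
        model_name == "trl-internal-testing/tiny-DbrxForCausalLM"
        and version.parse(transformers.__version__) < MIN_DBRX_TRANSFORMERS
    )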
