1616
1717import pytest
1818import torch
19+ import transformers
20+ from packaging import version
1921from parameterized import parameterized
2022from transformers import AutoModelForCausalLM , AutoModelForSeq2SeqLM , GenerationConfig
2123
@@ -63,6 +65,12 @@ def test_value_head(self):
6365 Test if the v-head is added to the model successfully
6466 """
6567 for model_name in self .all_model_names :
68+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
69+ transformers .__version__
70+ ) < version .parse ("4.58.0.dev0" ):
71+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
72+ continue
73+
6674 model = self .trl_model_class .from_pretrained (model_name )
6775 assert hasattr (model , "v_head" )
6876
@@ -71,6 +79,12 @@ def test_value_head_shape(self):
7179 Test if the v-head has the correct shape
7280 """
7381 for model_name in self .all_model_names :
82+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
83+ transformers .__version__
84+ ) < version .parse ("4.58.0.dev0" ):
85+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
86+ continue
87+
7488 model = self .trl_model_class .from_pretrained (model_name )
7589 assert model .v_head .summary .weight .shape [0 ] == 1
7690
@@ -80,6 +94,12 @@ def test_value_head_init_random(self):
8094 than zeros by default.
8195 """
8296 for model_name in self .all_model_names :
97+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
98+ transformers .__version__
99+ ) < version .parse ("4.58.0.dev0" ):
100+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
101+ continue
102+
83103 model = self .trl_model_class .from_pretrained (model_name )
84104 assert not torch .allclose (model .v_head .summary .bias , torch .zeros_like (model .v_head .summary .bias ))
85105
@@ -89,6 +109,12 @@ def test_value_head_not_str(self):
89109 `from_pretrained`.
90110 """
91111 for model_name in self .all_model_names :
112+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
113+ transformers .__version__
114+ ) < version .parse ("4.58.0.dev0" ):
115+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
116+ continue
117+
92118 pretrained_model = self .transformers_model_class .from_pretrained (model_name )
93119 model = self .trl_model_class .from_pretrained (pretrained_model )
94120 assert hasattr (model , "v_head" )
@@ -99,6 +125,12 @@ def test_from_save_trl(self):
99125 additional modules (e.g. v_head)
100126 """
101127 for model_name in self .all_model_names :
128+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
129+ transformers .__version__
130+ ) < version .parse ("4.58.0.dev0" ):
131+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
132+ continue
133+
102134 model = self .trl_model_class .from_pretrained (model_name )
103135
104136 model .save_pretrained (self .tmp_dir )
@@ -114,6 +146,12 @@ def test_from_save_trl_sharded(self):
114146 Test if the model can be saved and loaded from a directory and get the same weights - sharded case
115147 """
116148 for model_name in self .all_model_names :
149+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
150+ transformers .__version__
151+ ) < version .parse ("4.58.0.dev0" ):
152+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
153+ continue
154+
117155 model = self .trl_model_class .from_pretrained (model_name )
118156
119157 model .save_pretrained (self .tmp_dir )
@@ -129,6 +167,12 @@ def test_from_save_transformers_sharded(self):
129167 Test if the model can be saved and loaded using transformers and get the same weights - sharded case
130168 """
131169 for model_name in self .all_model_names :
170+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
171+ transformers .__version__
172+ ) < version .parse ("4.58.0.dev0" ):
173+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
174+ continue
175+
132176 transformers_model = self .trl_model_class .transformers_parent_class .from_pretrained (model_name )
133177
134178 trl_model = self .trl_model_class .from_pretrained (model_name )
@@ -150,6 +194,12 @@ def test_from_save_transformers(self):
150194 of the super class to check if the weights are the same.
151195 """
152196 for model_name in self .all_model_names :
197+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
198+ transformers .__version__
199+ ) < version .parse ("4.58.0.dev0" ):
200+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
201+ continue
202+
153203 transformers_model = self .trl_model_class .transformers_parent_class .from_pretrained (model_name )
154204
155205 trl_model = self .trl_model_class .from_pretrained (model_name )
@@ -200,6 +250,12 @@ def test_inference(self):
200250 EXPECTED_OUTPUT_SIZE = 3
201251
202252 for model_name in self .all_model_names :
253+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
254+ transformers .__version__
255+ ) < version .parse ("4.58.0.dev0" ):
256+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
257+ continue
258+
203259 model = self .trl_model_class .from_pretrained (model_name ).to (self .device )
204260 input_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ]], device = self .device )
205261 outputs = model (input_ids )
@@ -213,6 +269,12 @@ def test_dropout_config(self):
213269 Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
214270 """
215271 for model_name in self .all_model_names :
272+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
273+ transformers .__version__
274+ ) < version .parse ("4.58.0.dev0" ):
275+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
276+ continue
277+
216278 pretrained_model = self .transformers_model_class .from_pretrained (model_name )
217279 pretrained_model .config .summary_dropout_prob = 0.5
218280 model = self .trl_model_class .from_pretrained (pretrained_model )
@@ -225,6 +287,11 @@ def test_dropout_kwargs(self):
225287 Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
226288 """
227289 for model_name in self .all_model_names :
290+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
291+ transformers .__version__
292+ ) < version .parse ("4.58.0.dev0" ):
293+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
294+ continue
228295 v_head_kwargs = {"summary_dropout_prob" : 0.5 }
229296
230297 model = self .trl_model_class .from_pretrained (model_name , ** v_head_kwargs )
@@ -242,6 +309,12 @@ def test_generate(self, model_name):
242309 r"""
243310 Test if `generate` works for every model
244311 """
312+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
313+ transformers .__version__
314+ ) < version .parse ("4.58.0.dev0" ):
315+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
316+ pytest .xfail ("DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version" )
317+
245318 generation_config = GenerationConfig (max_new_tokens = 9 )
246319 model = self .trl_model_class .from_pretrained (model_name ).to (self .device )
247320 input_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ]], device = self .device )
@@ -256,6 +329,12 @@ def test_transformers_bf16_kwargs(self):
256329 run a dummy forward pass without any issue.
257330 """
258331 for model_name in self .all_model_names :
332+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
333+ transformers .__version__
334+ ) < version .parse ("4.58.0.dev0" ):
335+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
336+ continue
337+
259338 trl_model = self .trl_model_class .from_pretrained (model_name , dtype = torch .bfloat16 ).to (self .device )
260339
261340 lm_head_namings = ["lm_head" , "embed_out" , "output_layer" ]
@@ -276,6 +355,12 @@ def test_transformers_bf16_kwargs(self):
276355 @pytest .mark .skip (reason = "This test needs to be run manually due to HF token issue." )
277356 def test_push_to_hub (self ):
278357 for model_name in self .all_model_names :
358+ if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version .parse (
359+ transformers .__version__
360+ ) < version .parse ("4.58.0.dev0" ):
361+ # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
362+ continue
363+
279364 model = AutoModelForCausalLMWithValueHead .from_pretrained (model_name )
280365 if "sharded" in model_name :
281366 model .push_to_hub (model_name + "-ppo" , use_auth_token = True , max_shard_size = "1MB" )
0 commit comments