
Commit 94e4135

🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead (huggingface#2398)
* Remove lm_head check in `AutoModelForCausalLMWithValueHead`
* Style
* Remove test
1 parent ac26778 commit 94e4135
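
For orientation, a minimal usage sketch (not part of the diff; the tiny test model id is taken from the test removed below): after this commit, `AutoModelForCausalLMWithValueHead` no longer inspects the backbone for an `lm_head`-style attribute in `__init__`, so wrapping a model whose head uses a non-standard name simply attaches the value head instead of raising a ValueError.

    import torch
    from trl import AutoModelForCausalLMWithValueHead

    # Wrap a causal LM with a scalar value head; the former lm_head check is gone,
    # so construction succeeds even if the backbone's head attribute has a
    # non-standard name.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        "trl-internal-testing/tiny-GPT2LMHeadModel"
    )
    input_ids = torch.tensor([[1, 2, 3]])
    # The wrapper's forward returns language-model logits, loss, and per-token values.
    lm_logits, _, values = model(input_ids)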

2 files changed: +3 additions, −19 deletions


tests/test_modeling_value_head.py

Lines changed: 3 additions & 10 deletions
@@ -265,14 +265,6 @@ def test_generate(self, model_name):
         # Just check if the generation works
         _ = model.generate(input_ids, generation_config=generation_config)
 
-    def test_raise_error_not_causallm(self):
-        # Test with a model without a LM head
-        model_id = "trl-internal-testing/tiny-GPT2LMHeadModel"
-        # This should raise a ValueError
-        with self.assertRaises(ValueError):
-            pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
-            _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
-
     def test_transformers_bf16_kwargs(self):
         r"""
         Test if the transformers kwargs are correctly passed
@@ -283,10 +275,11 @@ def test_transformers_bf16_kwargs(self):
         for model_name in self.all_model_names:
             trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16)
 
-            lm_head_namings = self.trl_model_class.lm_head_namings
+            lm_head_namings = ["lm_head", "embed_out", "output_layer"]
 
             self.assertTrue(
-                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+                any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings),
+                "Can't test the model because it doesn't have any of the expected lm_head namings",
             )
 
             for lm_head_naming in lm_head_namings:

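The string added to `assertTrue` is unittest's optional `msg` argument, which is printed when the assertion fails. A self-contained sketch of that pattern (the `SimpleNamespace` object is a hypothetical stand-in for `trl_model.pretrained_model`):

    import unittest
    from types import SimpleNamespace

    class ExampleTest(unittest.TestCase):
        def test_has_expected_head(self):
            lm_head_namings = ["lm_head", "embed_out", "output_layer"]
            backbone = SimpleNamespace(lm_head=object())  # stand-in for trl_model.pretrained_model
            # The second argument to assertTrue is shown in the failure output,
            # which is what the added message in the diff provides.
            self.assertTrue(
                any(hasattr(backbone, name) for name in lm_head_namings),
                "Can't test the model because it doesn't have any of the expected lm_head namings",
            )

    if __name__ == "__main__":
        unittest.main()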
trl/models/modeling_value_head.py

Lines changed: 0 additions & 9 deletions
@@ -69,9 +69,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
     Class attributes:
         - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
           should be set to `transformers.AutoModelForCausalLM` for this class.
-        - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
-          wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed
-          for other models in the future
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
          by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
@@ -86,7 +83,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
     """
 
     transformers_parent_class = AutoModelForCausalLM
-    lm_head_namings = ["lm_head", "embed_out", "output_layer"]
     supported_args = (
         "summary_dropout_prob",
         "v_head_initializer_range",
@@ -106,12 +102,7 @@ def __init__(self, pretrained_model, **kwargs):
         """
         super().__init__(pretrained_model, **kwargs)
         v_head_kwargs, _, _ = self._split_kwargs(kwargs)
-
-        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
-            raise ValueError("The model does not have a language model head, please use a model that has one.")
-
         self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
-
         self._init_weights(**v_head_kwargs)
 
     def _init_weights(self, **kwargs):

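The check could be dropped because the value head only consumes hidden states sized by `config.hidden_size`; it never touches the backbone's language-model head. A minimal sketch illustrating that independence (`TinyValueHead` is a hypothetical stand-in, not trl's actual `ValueHead` implementation):

    import torch
    import torch.nn as nn
    from types import SimpleNamespace

    class TinyValueHead(nn.Module):
        """Maps per-token hidden states to one scalar value per token."""

        def __init__(self, config, summary_dropout_prob=0.0):
            super().__init__()
            self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
            # Only config.hidden_size is needed; nothing about lm_head is assumed.
            self.summary = nn.Linear(config.hidden_size, 1)

        def forward(self, hidden_states):
            return self.summary(self.dropout(hidden_states)).squeeze(-1)

    head = TinyValueHead(SimpleNamespace(hidden_size=16))
    values = head(torch.randn(2, 5, 16))  # shape (2, 5): one value per token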