Skip to content

Commit 1850da5

Browse files
Fix type for model_init_kwargs when passed as CLI JSON string (#5230)
1 parent 84ca123 commit 1850da5

File tree

10 files changed: 16 additions and 10 deletions.

trl/experimental/bco/bco_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ class BCOConfig(_BaseConfig):
142142
"needed."
143143
},
144144
)
145-
model_init_kwargs: dict[str, Any] | None = field(
145+
model_init_kwargs: dict[str, Any] | str | None = field(
146146
default=None,
147147
metadata={
148148
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

trl/experimental/cpo/cpo_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ class CPOConfig(_BaseConfig):
163163
default=None,
164164
metadata={"help": "Whether the model is an encoder-decoder model."},
165165
)
166-
model_init_kwargs: dict[str, Any] | None = field(
166+
model_init_kwargs: dict[str, Any] | str | None = field(
167167
default=None,
168168
metadata={
169169
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "

trl/experimental/kto/kto_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ class KTOConfig(_BaseConfig):
132132
"This is useful when training without the reference model to reduce the total GPU memory needed."
133133
},
134134
)
135-
model_init_kwargs: dict[str, Any] | None = field(
135+
model_init_kwargs: dict[str, Any] | str | None = field(
136136
default=None,
137137
metadata={
138138
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "

trl/experimental/online_dpo/online_dpo_config.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
from dataclasses import dataclass, field
1717
from typing import Any
1818

19+
from transformers import TrainingArguments
20+
1921
from ...trainer.base_config import _BaseConfig
2022

2123

@@ -159,6 +161,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
159161
> - `learning_rate`: Defaults to `5e-7` instead of `5e-5`.
160162
"""
161163

164+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
165+
162166
# Parameters whose default values are overridden from TrainingArguments
163167
learning_rate: float = field(
164168
default=5e-7,
@@ -361,7 +365,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
361365
"is not compatible with vLLM generation."
362366
},
363367
)
364-
model_init_kwargs: dict[str, Any] | None = field(
368+
model_init_kwargs: dict[str, Any] | str | None = field(
365369
default=None,
366370
metadata={
367371
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "

trl/experimental/orpo/orpo_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ class ORPOConfig(_BaseConfig):
123123
"argument, you need to specify if the model returned by the callable is an encoder-decoder model."
124124
},
125125
)
126-
model_init_kwargs: dict[str, Any] | None = field(
126+
model_init_kwargs: dict[str, Any] | str | None = field(
127127
default=None,
128128
metadata={
129129
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "

trl/trainer/dpo_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ class DPOConfig(_BaseConfig):
138138
)
139139

140140
# Parameters that control the model
141-
model_init_kwargs: dict[str, Any] | None = field(
141+
model_init_kwargs: dict[str, Any] | str | None = field(
142142
default=None,
143143
metadata={
144144
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "

trl/trainer/grpo_config.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass, field
16+
from typing import Any
1617

1718
from transformers import TrainingArguments
1819

@@ -332,7 +333,7 @@ class GRPOConfig(_BaseConfig):
332333
)
333334

334335
# Parameters that control the model and reference model
335-
model_init_kwargs: dict | str | None = field(
336+
model_init_kwargs: dict[str, Any] | str | None = field(
336337
default=None,
337338
metadata={
338339
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "

trl/trainer/reward_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
9191
)
9292

9393
# Parameters that control the model
94-
model_init_kwargs: dict[str, Any] | None = field(
94+
model_init_kwargs: dict[str, Any] | str | None = field(
9595
default=None,
9696
metadata={
9797
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "

trl/trainer/rloo_config.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass, field
16+
from typing import Any
1617

1718
from transformers import TrainingArguments
1819

@@ -224,7 +225,7 @@ class RLOOConfig(_BaseConfig):
224225
)
225226

226227
# Parameters that control the model and reference model
227-
model_init_kwargs: dict | str | None = field(
228+
model_init_kwargs: dict[str, Any] | str | None = field(
228229
default=None,
229230
metadata={
230231
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "

trl/trainer/sft_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ class SFTConfig(_BaseConfig):
121121
)
122122

123123
# Parameters that control the model
124-
model_init_kwargs: dict[str, Any] | None = field(
124+
model_init_kwargs: dict[str, Any] | str | None = field(
125125
default=None,
126126
metadata={
127127
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "

0 commit comments

Comments
 (0)