File tree Expand file tree Collapse file tree 10 files changed +16
-10
lines changed
Expand file tree Collapse file tree 10 files changed +16
-10
lines changed Original file line number Diff line number Diff line change @@ -142,7 +142,7 @@ class BCOConfig(_BaseConfig):
142142 "needed."
143143 },
144144 )
145- model_init_kwargs : dict [str , Any ] | None = field (
145+ model_init_kwargs : dict [str , Any ] | str | None = field (
146146 default = None ,
147147 metadata = {
148148 "help" : "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
Original file line number Diff line number Diff line change @@ -163,7 +163,7 @@ class CPOConfig(_BaseConfig):
163163 default = None ,
164164 metadata = {"help" : "Whether the model is an encoder-decoder model." },
165165 )
166- model_init_kwargs : dict [str , Any ] | None = field (
166+ model_init_kwargs : dict [str , Any ] | str | None = field (
167167 default = None ,
168168 metadata = {
169169 "help" : "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
Original file line number Diff line number Diff line change @@ -132,7 +132,7 @@ class KTOConfig(_BaseConfig):
132132 "This is useful when training without the reference model to reduce the total GPU memory needed."
133133 },
134134 )
135- model_init_kwargs : dict [str , Any ] | None = field (
135+ model_init_kwargs : dict [str , Any ] | str | None = field (
136136 default = None ,
137137 metadata = {
138138 "help" : "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
Original file line number Diff line number Diff line change 1616from dataclasses import dataclass , field
1717from typing import Any
1818
19+ from transformers import TrainingArguments
20+
1921from ...trainer .base_config import _BaseConfig
2022
2123
@@ -159,6 +161,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
159161 > - `learning_rate`: Defaults to `5e-7` instead of `5e-5`.
160162 """
161163
164+ _VALID_DICT_FIELDS = TrainingArguments ._VALID_DICT_FIELDS + ["model_init_kwargs" ]
165+
162166 # Parameters whose default values are overridden from TrainingArguments
163167 learning_rate : float = field (
164168 default = 5e-7 ,
@@ -361,7 +365,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
361365 "is not compatible with vLLM generation."
362366 },
363367 )
364- model_init_kwargs : dict [str , Any ] | None = field (
368+ model_init_kwargs : dict [str , Any ] | str | None = field (
365369 default = None ,
366370 metadata = {
367371 "help" : "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ class ORPOConfig(_BaseConfig):
123123 "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
124124 },
125125 )
126- model_init_kwargs : dict [str , Any ] | None = field (
126+ model_init_kwargs : dict [str , Any ] | str | None = field (
127127 default = None ,
128128 metadata = {
129129 "help" : "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
Original file line number Diff line number Diff line change @@ -138,7 +138,7 @@ class DPOConfig(_BaseConfig):
138138 )
139139
140140 # Parameters that control the model
141- model_init_kwargs : dict [str , Any ] | None = field (
141+ model_init_kwargs : dict [str , Any ] | str | None = field (
142142 default = None ,
143143 metadata = {
144144 "help" : "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
Original file line number Diff line number Diff line change 1313# limitations under the License.
1414
1515from dataclasses import dataclass , field
16+ from typing import Any
1617
1718from transformers import TrainingArguments
1819
@@ -332,7 +333,7 @@ class GRPOConfig(_BaseConfig):
332333 )
333334
334335 # Parameters that control the model and reference model
335- model_init_kwargs : dict | str | None = field (
336+ model_init_kwargs : dict [ str , Any ] | str | None = field (
336337 default = None ,
337338 metadata = {
338339 "help" : "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
9191 )
9292
9393 # Parameters that control the model
94- model_init_kwargs : dict [str , Any ] | None = field (
94+ model_init_kwargs : dict [str , Any ] | str | None = field (
9595 default = None ,
9696 metadata = {
9797 "help" : "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
Original file line number Diff line number Diff line change 1313# limitations under the License.
1414
1515from dataclasses import dataclass , field
16+ from typing import Any
1617
1718from transformers import TrainingArguments
1819
@@ -224,7 +225,7 @@ class RLOOConfig(_BaseConfig):
224225 )
225226
226227 # Parameters that control the model and reference model
227- model_init_kwargs : dict | str | None = field (
228+ model_init_kwargs : dict [ str , Any ] | str | None = field (
228229 default = None ,
229230 metadata = {
230231 "help" : "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
Original file line number Diff line number Diff line change @@ -121,7 +121,7 @@ class SFTConfig(_BaseConfig):
121121 )
122122
123123 # Parameters that control the model
124- model_init_kwargs : dict [str , Any ] | None = field (
124+ model_init_kwargs : dict [str , Any ] | str | None = field (
125125 default = None ,
126126 metadata = {
127127 "help" : "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
You can’t perform that action at this time.
0 commit comments