File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -361,11 +361,12 @@ def train(
361361 # from our object directly. In the future, we should consider renaming this class and / or
362362 # not adding things that are not directly used by the trainer instance to it.
363363
364- transformer_train_arg_fields = [x .name for x in dataclasses .fields (SFTConfig )]
364+ # To filter out fields that are not defined as init (eg. _n_gpu)
365+ transformer_train_arg_fields = [
366+ x .name for x in dataclasses .fields (SFTConfig ) if x .init
367+ ]
365368 transformer_kwargs = {
366- k : v
367- for k , v in train_args .to_dict ().items ()
368- if k in transformer_train_arg_fields
369+ k : v for k , v in vars (train_args ).items () if k in transformer_train_arg_fields
369370 }
370371
371372 additional_args = {
You can’t perform that action at this time.
0 commit comments