
Commit a61fc6a

Fix typing of train_args (#41142)
* Fix typing

  Signed-off-by: Yuanyuan Chen <[email protected]>

* Fix fsdp typing

  Signed-off-by: Yuanyuan Chen <[email protected]>
1 parent 919a484 commit a61fc6a

File tree

1 file changed: +17 −18 lines changed


src/transformers/training_args.py

Lines changed: 17 additions & 18 deletions
@@ -476,7 +476,7 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
             can take a long time) but will not yield the same results as the interrupted training would have.
-        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
+        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `[]`):
             Use PyTorch Distributed Parallel Training (in distributed training only).
 
             A list of options along the following:
@@ -738,11 +738,10 @@ class TrainingArguments:
             Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
 
             This flag is experimental and subject to change in future releases.
-        include_tokens_per_second (`bool`, *optional*):
+        include_tokens_per_second (`bool`, *optional*, defaults to `False`):
             Whether or not to compute the number of tokens per second per device for training speed metrics.
 
             This will iterate over the entire training dataloader once beforehand,
-
             and will slow down the entire process.
 
         include_num_input_tokens_seen (`bool`, *optional*):
@@ -761,7 +760,7 @@ class TrainingArguments:
             See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
             You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.
 
-        batch_eval_metrics (`Optional[bool]`, defaults to `False`):
+        batch_eval_metrics (`bool`, *optional*, defaults to `False`):
             If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
             rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
@@ -845,7 +844,7 @@ class TrainingArguments:
         metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
     )
 
-    eval_delay: Optional[float] = field(
+    eval_delay: float = field(
         default=0,
         metadata={
             "help": (
@@ -880,7 +879,7 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
+    lr_scheduler_kwargs: Union[dict[str, Any], str] = field(
         default_factory=dict,
         metadata={
             "help": (
@@ -963,7 +962,7 @@ class TrainingArguments:
             )
         },
     )
-    save_safetensors: Optional[bool] = field(
+    save_safetensors: bool = field(
         default=True,
         metadata={
             "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
@@ -1117,13 +1116,13 @@ class TrainingArguments:
         default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
     )
 
-    remove_unused_columns: Optional[bool] = field(
+    remove_unused_columns: bool = field(
         default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
     )
     label_names: Optional[list[str]] = field(
         default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
     )
-    load_best_model_at_end: Optional[bool] = field(
+    load_best_model_at_end: bool = field(
         default=False,
         metadata={
             "help": (
@@ -1147,8 +1146,8 @@ class TrainingArguments:
             )
         },
     )
-    fsdp: Optional[Union[list[FSDPOption], str]] = field(
-        default="",
+    fsdp: Union[list[FSDPOption], str, bool] = field(
+        default_factory=list,
         metadata={
             "help": (
                 "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
@@ -1209,7 +1208,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
     )
-    length_column_name: Optional[str] = field(
+    length_column_name: str = field(
         default="length",
         metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
     )
@@ -1338,7 +1337,7 @@ class TrainingArguments:
             )
         },
     )
-    ray_scope: Optional[str] = field(
+    ray_scope: str = field(
         default="last",
         metadata={
             "help": (
@@ -1373,12 +1372,12 @@ class TrainingArguments:
         },
     )
 
-    include_tokens_per_second: Optional[bool] = field(
+    include_tokens_per_second: bool = field(
         default=False,
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
     )
 
-    include_num_input_tokens_seen: Optional[Union[str, bool]] = field(
+    include_num_input_tokens_seen: Union[str, bool] = field(
         default=False,
         metadata={
             "help": (
@@ -1415,7 +1414,7 @@ class TrainingArguments:
         },
     )
 
-    use_liger_kernel: Optional[bool] = field(
+    use_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
     )
@@ -1433,14 +1432,14 @@ class TrainingArguments:
         },
     )
 
-    eval_use_gather_object: Optional[bool] = field(
+    eval_use_gather_object: bool = field(
         default=False,
         metadata={
             "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
         },
     )
 
-    average_tokens_across_devices: Optional[bool] = field(
+    average_tokens_across_devices: bool = field(
         default=True,
         metadata={
             "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "

0 commit comments
