@@ -476,7 +476,7 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
             can take a long time) but will not yield the same results as the interrupted training would have.
-        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
+        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `[]`):
             Use PyTorch Distributed Parallel Training (in distributed training only).

             A list of options along the following:
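A minimal sketch of the accepted forms for `fsdp`, assuming a standard multi-GPU launch; the option strings are taken from the FSDP option list in the full docstring, and the bool shorthand mapping is an assumption:

```python
from transformers import TrainingArguments

# str form: whitespace-separated FSDP options
args = TrainingArguments(output_dir="out", fsdp="full_shard auto_wrap")

# list form, equivalent to the string above
# args = TrainingArguments(output_dir="out", fsdp=["full_shard", "auto_wrap"])

# bool form (assumed to enable the default full-sharding strategy)
# args = TrainingArguments(output_dir="out", fsdp=True)

# disabled: with this change the default is an empty list rather than ""
# args = TrainingArguments(output_dir="out", fsdp=[])
```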
@@ -738,11 +738,10 @@ class TrainingArguments:
             Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.

             This flag is experimental and subject to change in future releases.
-        include_tokens_per_second (`bool`, *optional*):
+        include_tokens_per_second (`bool`, *optional*, defaults to `False`):
             Whether or not to compute the number of tokens per second per device for training speed metrics.

             This will iterate over the entire training dataloader once beforehand,
-
             and will slow down the entire process.

         include_num_input_tokens_seen (`bool`, *optional*):
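For illustration, a small sketch of opting into the throughput metrics these docstrings describe (the flag names come from this file; the `output_dir` value is arbitrary):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    include_tokens_per_second=True,      # adds `tgs` (tokens per second per device) to speed metrics,
                                         # at the cost of one extra pass over the training dataloader
    include_num_input_tokens_seen=True,  # additionally tracks the running count of input tokens seen
)
```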
@@ -761,7 +760,7 @@ class TrainingArguments:
             See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
             You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.

-        batch_eval_metrics (`Optional[bool]`, defaults to `False`):
+        batch_eval_metrics (`bool`, *optional*, defaults to `False`):
             If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
             rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
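A minimal sketch of a `compute_metrics` callable matching the `compute_result` contract described above; the per-batch prediction/label types and the reset-on-final-batch behavior are assumptions for illustration:

```python
class BatchAccuracy:
    """Accumulates accuracy batch by batch instead of keeping all eval logits in memory."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, eval_pred, compute_result: bool = False):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
        preds = logits.argmax(-1)                       # works for torch tensors and numpy arrays
        self.correct += int((preds == labels).sum())
        self.total += int(labels.reshape(-1).shape[0])
        if compute_result:                              # final batch: emit the global statistic and reset
            accuracy = self.correct / max(self.total, 1)
            self.correct = self.total = 0
            return {"accuracy": accuracy}
        return {}
```

This would be passed as `Trainer(..., compute_metrics=BatchAccuracy())` together with `TrainingArguments(batch_eval_metrics=True, ...)`.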
@@ -845,7 +844,7 @@ class TrainingArguments:
         metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
     )

-    eval_delay: Optional[float] = field(
+    eval_delay: float = field(
        default=0,
        metadata={
            "help": (
@@ -880,7 +879,7 @@ class TrainingArguments:
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
-    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
+    lr_scheduler_kwargs: Union[dict[str, Any], str] = field(
         default_factory=dict,
         metadata={
             "help": (
@@ -963,7 +962,7 @@ class TrainingArguments:
             )
         },
     )
-    save_safetensors: Optional[bool] = field(
+    save_safetensors: bool = field(
         default=True,
         metadata={
             "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
@@ -1117,13 +1116,13 @@ class TrainingArguments:
         default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
     )

-    remove_unused_columns: Optional[bool] = field(
+    remove_unused_columns: bool = field(
         default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
     )
     label_names: Optional[list[str]] = field(
         default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
     )
-    load_best_model_at_end: Optional[bool] = field(
+    load_best_model_at_end: bool = field(
         default=False,
         metadata={
             "help": (
@@ -1147,8 +1146,8 @@ class TrainingArguments:
             )
         },
     )
-    fsdp: Optional[Union[list[FSDPOption], str]] = field(
-        default="",
+    fsdp: Union[list[FSDPOption], str, bool] = field(
+        default_factory=list,
         metadata={
             "help": (
                 "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
@@ -1209,7 +1208,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
     )
-    length_column_name: Optional[str] = field(
+    length_column_name: str = field(
         default="length",
         metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
     )
@@ -1338,7 +1337,7 @@ class TrainingArguments:
             )
         },
     )
-    ray_scope: Optional[str] = field(
+    ray_scope: str = field(
         default="last",
         metadata={
             "help": (
@@ -1373,12 +1372,12 @@ class TrainingArguments:
         },
     )

-    include_tokens_per_second: Optional[bool] = field(
+    include_tokens_per_second: bool = field(
         default=False,
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
     )

-    include_num_input_tokens_seen: Optional[Union[str, bool]] = field(
+    include_num_input_tokens_seen: Union[str, bool] = field(
         default=False,
         metadata={
             "help": (
@@ -1415,7 +1414,7 @@ class TrainingArguments:
         },
     )

-    use_liger_kernel: Optional[bool] = field(
+    use_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
     )
@@ -1433,14 +1432,14 @@ class TrainingArguments:
14331432 },
14341433 )
14351434
1436- eval_use_gather_object : Optional [ bool ] = field (
1435+ eval_use_gather_object : bool = field (
14371436 default = False ,
14381437 metadata = {
14391438 "help" : "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
14401439 },
14411440 )
14421441
1443- average_tokens_across_devices : Optional [ bool ] = field (
1442+ average_tokens_across_devices : bool = field (
14441443 default = True ,
14451444 metadata = {
14461445 "help" : "Whether or not to average tokens across devices. If enabled, will use all_reduce to "