@@ -98,9 +98,6 @@ def main():
 
     ####### A. Preparation
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
-    if training_args.torch_compile:
-        # TODO(YL): add more compile modes?
-        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  # reduce-overhead
 
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
@@ -129,6 +126,7 @@ def main():
             "adam_beta2": training_args.adam_beta2,
             "temperature": model_args.temperature,
         },
+        init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else None,
     )
 
     # Detecting last checkpoint and eventually continue from last checkpoint
@@ -538,7 +536,7 @@ def is_audio_in_length_range(length):
         logger.info(f"Dataset saved at {data_args.save_to_disk}")
 
     audio_max_length = None
-    if training_args.torch_compile:
+    if padding == "max_length":
         audio_max_length = max(vectorized_datasets["train"]["target_length"])
         with accelerator.main_process_first():
             max_sample = vectorized_datasets["train"].filter(
@@ -548,6 +546,18 @@ def is_audio_in_length_range(length):
             )
             audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
 
+    if training_args.group_by_length:
+        # apply a simple heuristic to take into account audio and text lengths
+        def add_target_lengths(target_length, prompt, description):
+            return {"target_length": target_length + len(prompt) + len(description)}
+
+        with accelerator.main_process_first():
+            vectorized_datasets = vectorized_datasets.map(
+                add_target_lengths,
+                num_proc=num_workers,
+                input_columns=["target_length", "prompt_input_ids", "input_ids"],
+            )
+
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
     # be a timeout when running the script in distributed mode.
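
For reference, the added group_by_length block only folds the prompt and description token counts into the precomputed audio target_length, so a length-grouped sampler can bucket on a single column. Below is a minimal, self-contained sketch of the same datasets.map(..., input_columns=...) pattern; the toy dataset, its values, and num_proc=1 are illustrative and not part of the patch.

from datasets import Dataset

# Toy rows mirroring the columns used in the patch; the values are made up.
toy = Dataset.from_dict(
    {
        "target_length": [1200, 800],           # e.g. audio frame counts
        "prompt_input_ids": [[5, 6, 7], [5]],   # tokenized prompt/transcript
        "input_ids": [[1, 2], [1, 2, 3, 4]],    # tokenized description
    }
)

# Same heuristic as the added code: add the text lengths to the audio length.
def add_target_lengths(target_length, prompt, description):
    return {"target_length": target_length + len(prompt) + len(description)}

toy = toy.map(
    add_target_lengths,
    num_proc=1,
    input_columns=["target_length", "prompt_input_ids", "input_ids"],
)
print(toy["target_length"])  # [1205, 805]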