Skip to content

Commit 9232a47

Browse files
authored
Merge pull request #53 from ylacombe/nits-improvements
[Training] Small nits
2 parents 5518cc2 + a0bc9e7 commit 9232a47

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

training/arguments.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class DataTrainingArguments:
218218
metadata={
219219
"help": (
220220
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
221-
"Also, used to set maximum desription token length if `pad_to_max_length=True`."
221+
"Also, used to set maximum description token length if `pad_to_max_length=True`."
222222
)
223223
},
224224
)
@@ -277,6 +277,12 @@ class DataTrainingArguments:
277277
default="parler-speech",
278278
metadata={"help": "The name of the wandb project."},
279279
)
280+
wandb_run_name: str = field(
281+
default=None,
282+
metadata={
283+
"help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
284+
},
285+
)
280286
save_to_disk: str = field(
281287
default=None,
282288
metadata={

training/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
3131
audios = [feature[self.audio_column_name]["array"] for feature in features]
3232
len_audio = [len(audio) for audio in audios]
3333

34-
# since resampling has already been performed in the 'load_multiple_datasets' function,
34+
# since resampling has already been performed in the 'load_multiple_datasets' function,
3535
# a fixed sampling_rate(44100hz) is passed to the feature_extractor.
3636
sampling_rate = self.feature_extractor.sampling_rate
3737
batch = self.feature_extractor(

training/run_parler_tts_training.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ def main():
9898

9999
####### A. Preparation
100100
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
101-
if training_args.torch_compile:
102-
# TODO(YL): add more compile modes?
103-
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
104101

105102
accelerator = Accelerator(
106103
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
@@ -129,6 +126,7 @@ def main():
129126
"adam_beta2": training_args.adam_beta2,
130127
"temperature": model_args.temperature,
131128
},
129+
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else None,
132130
)
133131

134132
# Detecting last checkpoint and eventually continue from last checkpoint
@@ -538,7 +536,7 @@ def is_audio_in_length_range(length):
538536
logger.info(f"Dataset saved at {data_args.save_to_disk}")
539537

540538
audio_max_length = None
541-
if training_args.torch_compile:
539+
if padding == "max_length":
542540
audio_max_length = max(vectorized_datasets["train"]["target_length"])
543541
with accelerator.main_process_first():
544542
max_sample = vectorized_datasets["train"].filter(
@@ -548,6 +546,18 @@ def is_audio_in_length_range(length):
548546
)
549547
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
550548

549+
if training_args.group_by_length:
550+
# apply a simple heuristic to take into account audio and text lengths
551+
def add_target_lengths(target_length, prompt, description):
552+
return {"target_length": target_length + len(prompt) + len(description)}
553+
554+
with accelerator.main_process_first():
555+
vectorized_datasets = vectorized_datasets.map(
556+
add_target_lengths,
557+
num_proc=num_workers,
558+
input_columns=["target_length", "prompt_input_ids", "input_ids"],
559+
)
560+
551561
# for large datasets it is advised to run the preprocessing on a
552562
# single machine first with ``args.preprocessing_only`` since there will mostly likely
553563
# be a timeout when running the script in distributed mode.

0 commit comments

Comments (0)