Commit 529d856

clean-up eval scripts
1 parent 7abadc4 commit 529d856

8 files changed: +233 -30 lines changed

src/f5_tts/eval/README.md

Lines changed: 17 additions & 6 deletions
@@ -14,16 +14,20 @@ pip install -e .[eval]
 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
 3. Unzip the downloaded datasets and place them in the `data/` directory.
-4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
-5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
+4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
 
 ### Batch Inference for Test Set
 
 To run batch inference for evaluations, execute the following commands:
 
 ```bash
-# batch inference for evaluations
-accelerate config  # if not set before
+# if accelerate config is not set up yet
+accelerate config
+
+# to perform inference only
+bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
+
+# for inference with corresponding evaluation, set up the following tools first
 bash src/f5_tts/eval/eval_infer_batch.sh
 ```

@@ -35,9 +39,13 @@ bash src/f5_tts/eval/eval_infer_batch.sh
 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
 
-Then update in the following scripts with the paths you put evaluation model ckpts to.
+> [!NOTE]
+> The ASR model will be automatically downloaded if `--local` is not set for the evaluation scripts.
+> Otherwise, update the `asr_ckpt_dir` path in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
+>
+> The WavLM model must be downloaded manually, with the `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
 
-### Objective Evaluation
+### Objective Evaluation Examples
 
 Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
 ```bash

@@ -50,3 +58,6 @@ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_
 # Evaluation [UTMOS]. --ext: Audio extension
 python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
 ```
+
+> [!NOTE]
+> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
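Per the note above, each eval script leaves a `_*_results.jsonl` file in the output directory. A minimal sketch for gathering those files into one list of records (the field names in the demo record are illustrative, not the scripts' actual schema):

```python
import glob
import json
import os
import tempfile

def collect_results(result_dir):
    """Collect all `_*_results.jsonl` files under result_dir into a list of dicts."""
    records = []
    for path in sorted(glob.glob(os.path.join(result_dir, "_*_results.jsonl"))):
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    records.append(json.loads(line))
    return records

# Demo with a throwaway directory; the record fields are made up for illustration
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "_wer_results.jsonl"), "w") as f:
    f.write(json.dumps({"wav": "a.wav", "wer": 0.021}) + "\n")

rows = collect_results(demo_dir)
```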

src/f5_tts/eval/eval_infer_batch.py

Lines changed: 7 additions & 2 deletions
@@ -48,6 +48,11 @@ def main():
     parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
 
     parser.add_argument("-t", "--testset", required=True)
+    parser.add_argument(
+        "-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
+    )
+
+    parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")
 
     args = parser.parse_args()

@@ -83,7 +88,7 @@ def main():
     if testset == "ls_pc_test_clean":
         metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
-        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+        librispeech_test_clean_path = args.librispeech_test_clean_path
         metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
 
     elif testset == "seedtts_test_zh":

@@ -121,7 +126,7 @@ def main():
     )
 
     # Vocoder model
-    local = False
+    local = args.local
     if mel_spec_type == "vocos":
         vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
     elif mel_spec_type == "bigvgan":
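The new options replace a hardcoded `<SOME_PATH>` placeholder and a `local = False` constant. They can be exercised standalone; a minimal argparse sketch mirroring the added flags (`rel_path` below is a stand-in for the repo-root path the real script computes):

```python
import argparse

rel_path = "."  # stand-in for the repo-relative base path derived in the real script

parser = argparse.ArgumentParser()
parser.add_argument("-t", "--testset", required=True)
parser.add_argument(
    "-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
)
parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")

# With no overrides, the path falls back to the in-repo default and --local stays False
args = parser.parse_args(["-t", "ls_pc_test_clean"])
```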
src/f5_tts/eval/eval_infer_batch.sh

Lines changed: 111 additions & 13 deletions

@@ -1,18 +1,116 @@
 #!/bin/bash
+set -e
+export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"
 
-# e.g. F5-TTS, 16 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
+# Configuration parameters
+MODEL_NAME="F5TTS_v1_Base"
+SEEDS=(0 1 2)
+CKPTSTEPS=(1250000)
+TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
+LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
+GPUS="[0,1,2,3,4,5,6,7]"
+OFFLINE_MODE=false
 
-# e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
+# Resolve the --local flag from offline mode
+if [ "$OFFLINE_MODE" = true ]; then
+    LOCAL="--local"
+else
+    LOCAL=""
+fi
+
+# Parse arguments
+INFER_ONLY=false
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --infer-only)
+            INFER_ONLY=true
+            shift
+            ;;
+        *)
+            echo "======== Unknown parameter: $1"
+            exit 1
+            ;;
+    esac
+done
 
-# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
-python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
-python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
-python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
+echo "======== Starting F5-TTS batch evaluation task..."
+if [ "$INFER_ONLY" = true ]; then
+    echo "======== Mode: Execute infer tasks only"
+else
+    echo "======== Mode: Execute full pipeline (infer + eval)"
+fi
 
-# etc.
+# Function: run the eval scripts for one (ckptstep, seed, task) triple
+execute_eval_tasks() {
+    local ckptstep=$1
+    local seed=$2
+    local task_name=$3
+
+    local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"
+
+    echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
+
+    case $task_name in
+        "seedtts_test_zh")
+            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
+            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
+            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
+            ;;
+        "seedtts_test_en")
+            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
+            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
+            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
+            ;;
+        "ls_pc_test_clean")
+            python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
+            python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
+            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
+            ;;
+    esac
+
+    echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
+}
+
+# Main execution loop
+for ckptstep in "${CKPTSTEPS[@]}"; do
+    echo "======== Processing ckptstep: ${ckptstep}"
+
+    for seed in "${SEEDS[@]}"; do
+        echo "-------- Processing seed: ${seed}"
+
+        # Store eval task PIDs for current seed (if not infer-only mode)
+        if [ "$INFER_ONLY" = false ]; then
+            declare -a eval_pids
+        fi
+
+        # Execute each infer task sequentially
+        for task in "${TASKS[@]}"; do
+            echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL"
+
+            # Run the infer task in the foreground, waiting for completion
+            accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL
+
+            # Unless infer-only, launch the corresponding eval task in the
+            # background so it does not block the next infer task
+            if [ "$INFER_ONLY" = false ]; then
+                execute_eval_tasks $ckptstep $seed $task &
+                eval_pids+=($!)
+            fi
+        done
+
+        # Wait for all background eval tasks of the current seed to complete
+        if [ "$INFER_ONLY" = false ]; then
+            echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
+            for pid in "${eval_pids[@]}"; do
+                wait $pid
+            done
+            unset eval_pids  # Clean up array
+        fi
+        echo "-------- All eval tasks for seed ${seed} completed"
+    done
+
+    echo "======== Completed ckptstep: ${ckptstep}"
+    echo
+done
+
+echo "======== All tasks completed!"
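The script pipelines work: each infer task runs in the foreground, its eval tasks are launched in the background, and the script waits for all of a seed's evals before moving on. The same overlap pattern, sketched in Python with stand-in functions (`run_infer`/`run_eval` are placeholders for the `accelerate launch` and eval-script invocations, not repo APIs):

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait

def run_infer(task):
    """Stand-in for a blocking `accelerate launch eval_infer_batch.py` call."""
    time.sleep(0.01)
    return f"results/{task}"

def run_eval(gen_wav_dir):
    """Stand-in for the WER/SIM/UTMOS eval scripts run on one output dir."""
    time.sleep(0.02)
    return f"evaluated:{gen_wav_dir}"

tasks = ["seedtts_test_zh", "seedtts_test_en", "ls_pc_test_clean"]
completed = []

with ThreadPoolExecutor() as pool:
    futures = []
    for task in tasks:
        gen_wav_dir = run_infer(task)                       # foreground: one infer at a time
        futures.append(pool.submit(run_eval, gen_wav_dir))  # background: like `... &` + PID list
    done, _ = wait(futures)                                 # like `wait $pid` before the next seed
    completed = sorted(f.result() for f in futures)
```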
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# e.g. F5-TTS, 16 NFE
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean
+
+# e.g. Vanilla E2 TTS, 32 NFE
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean
+
+# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
+python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
+
+# etc.

src/f5_tts/eval/eval_librispeech_test_clean.py

Lines changed: 18 additions & 2 deletions
@@ -1,6 +1,7 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
 
 import argparse
+import ast
 import json
 import os
 import sys

@@ -25,11 +26,26 @@ def get_args():
     parser.add_argument("-l", "--lang", type=str, default="en")
     parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
     parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
-    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument(
+        "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
+    )
     parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
     return parser.parse_args()
 
 
+def parse_gpu_nums(gpu_nums_str):
+    try:
+        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
+            gpu_list = ast.literal_eval(gpu_nums_str)
+            if isinstance(gpu_list, list):
+                return gpu_list
+        return list(range(int(gpu_nums_str)))
+    except (ValueError, SyntaxError):
+        raise argparse.ArgumentTypeError(
+            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
+        )
+
+
 def main():
     args = get_args()
     eval_task = args.eval_task

@@ -38,7 +54,7 @@ def main():
     gen_wav_dir = args.gen_wav_dir
     metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
 
-    gpus = list(range(args.gpu_nums))
+    gpus = parse_gpu_nums(args.gpu_nums)
     test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
 
     ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
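The new `parse_gpu_nums` helper accepts either a GPU count (expanded to `range(n)`) or an explicit bracketed device list. Reproduced below from the diff so it can be run standalone, with a demonstration of both forms and the error path:

```python
import argparse
import ast

def parse_gpu_nums(gpu_nums_str):
    """Parse '8' -> [0..7], or '[0,2,5]' -> [0, 2, 5]; reject anything else."""
    try:
        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
            gpu_list = ast.literal_eval(gpu_nums_str)
            if isinstance(gpu_list, list):
                return gpu_list
        return list(range(int(gpu_nums_str)))
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(
            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
        )

# A plain count expands to consecutive device ids; a list is taken verbatim
count_form = parse_gpu_nums("4")
list_form = parse_gpu_nums("[0,2,5]")
```

Note the argument stays `type=str` and is parsed afterwards, so the bracketed form survives argparse untouched.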

src/f5_tts/eval/eval_seedtts_testset.py

Lines changed: 18 additions & 2 deletions
@@ -1,6 +1,7 @@
 # Evaluate with Seed-TTS testset
 
 import argparse
+import ast
 import json
 import os
 import sys

@@ -24,11 +25,26 @@ def get_args():
     parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
     parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
     parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
-    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument(
+        "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
+    )
     parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
     return parser.parse_args()
 
 
+def parse_gpu_nums(gpu_nums_str):
+    try:
+        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
+            gpu_list = ast.literal_eval(gpu_nums_str)
+            if isinstance(gpu_list, list):
+                return gpu_list
+        return list(range(int(gpu_nums_str)))
+    except (ValueError, SyntaxError):
+        raise argparse.ArgumentTypeError(
+            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
+        )
+
+
 def main():
     args = get_args()
     eval_task = args.eval_task

@@ -38,7 +54,7 @@ def main():
 
     # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
     # zh 1.254 seems a result of 4 workers wer_seed_tts
-    gpus = list(range(args.gpu_nums))
+    gpus = parse_gpu_nums(args.gpu_nums)
     test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
 
     local = args.local

src/f5_tts/eval/utils_eval.py

Lines changed: 12 additions & 5 deletions
@@ -395,14 +395,21 @@ def run_sim(args):
         wav1, sr1 = torchaudio.load(gen_wav)
         wav2, sr2 = torchaudio.load(prompt_wav)
 
-        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
-        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
-        wav1 = resample1(wav1)
-        wav2 = resample2(wav2)
-
         if use_gpu:
             wav1 = wav1.cuda(device)
             wav2 = wav2.cuda(device)
+
+        if sr1 != 16000:
+            resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
+            if use_gpu:
+                resample1 = resample1.cuda(device)
+            wav1 = resample1(wav1)
+        if sr2 != 16000:
+            resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
+            if use_gpu:
+                resample2 = resample2.cuda(device)
+            wav2 = resample2(wav2)
+
         with torch.no_grad():
             emb1 = model(wav1)
             emb2 = model(wav2)
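The change above skips constructing a `Resample` transform entirely when the audio is already at 16 kHz, and runs the transform on the same device as the waveform. A minimal pure-Python sketch of that guard (the nearest-neighbor resampler here is a crude stand-in for `torchaudio.transforms.Resample`, included only so the control flow is runnable):

```python
def naive_resample(samples, orig_sr, new_sr):
    """Crude nearest-neighbor resampler; stand-in for torchaudio.transforms.Resample."""
    n_out = int(len(samples) * new_sr / orig_sr)
    return [samples[min(int(i * orig_sr / new_sr), len(samples) - 1)] for i in range(n_out)]

def to_16k(samples, sr):
    # Guard mirrors the commit: avoid any resample work when already at 16 kHz
    if sr == 16000:
        return samples
    return naive_resample(samples, sr, 16000)

wav_48k = [0.0] * 48000          # one second of audio at 48 kHz
wav_16k_out = to_16k(wav_48k, 48000)
```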
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import math
+
+from torch.utils.data import SequentialSampler
+
+from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
+
+
+train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
+sampler = SequentialSampler(train_dataset)
+
+gpus = 8
+batch_size_per_gpu = 38400
+max_samples_per_gpu = 64
+max_updates = 1250000
+
+batch_sampler = DynamicBatchSampler(
+    sampler,
+    batch_size_per_gpu,
+    max_samples=max_samples_per_gpu,
+    random_seed=666,
+    drop_residual=False,
+)
+
+print(
+    f"One epoch has {len(batch_sampler) / gpus} updates if gpus={gpus}, with "
+    f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
+    f"max_samples_per_gpu={max_samples_per_gpu}."
+)
+print(
+    f"If gpus={gpus}, for max_updates={max_updates} "
+    f"should set epoch={math.ceil(max_updates / len(batch_sampler) * gpus)}."
+)
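The two prints reduce to simple arithmetic: batches are split across GPUs, so updates per epoch is `len(batch_sampler) / gpus`, and the epoch count for a target update budget is a ceiling division. Sketched with an illustrative batch count (the real `len(batch_sampler)` depends on the Emilia dataset and is not in the diff):

```python
import math

gpus = 8
max_updates = 1250000
batches_in_dataset = 903648  # illustrative stand-in for len(batch_sampler), not a real count

# Each GPU consumes 1/gpus of the batches, so one epoch yields this many updates
updates_per_epoch = batches_in_dataset / gpus

# Round up so the run reaches at least max_updates
epochs_needed = math.ceil(max_updates / updates_per_epoch)
```

Note `math.ceil(max_updates / len(batch_sampler) * gpus)` in the script is the same quantity, since dividing by `len / gpus` equals multiplying by `gpus / len`.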
