-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathsetup_ljspeech.py
More file actions
583 lines (477 loc) · 19.8 KB
/
setup_ljspeech.py
File metadata and controls
583 lines (477 loc) · 19.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
#!/usr/bin/env python3
"""
Setup script for LJSpeech dataset
Downloads dataset and optionally runs Montreal Forced Aligner
"""
import argparse
import logging
import os
import shutil
import subprocess
import sys
import zipfile
from pathlib import Path
# Root logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Remote sources: the LJSpeech-1.1 archive and the Zenodo record holding
# pre-computed MFA TextGrid alignments for it.
LJSPEECH_URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
LJSPEECH_ZENODO_ALIGNMENTS_URL = (
    "https://zenodo.org/api/records/7499098/files/grids.zip/content"
)

# Local file/directory names: the extracted dataset folder and the archives
# that are downloaded next to it (and removed after extraction).
LJSPEECH_DIR = "LJSpeech-1.1"
LJSPEECH_ARCHIVE = f"{LJSPEECH_DIR}.tar.bz2"
ALIGNMENTS_ARCHIVE = "grids.zip"
def check_command_exists(command: str) -> bool:
    """Return True if *command* resolves to an executable.

    Uses :func:`shutil.which` instead of launching ``command --version``.
    That spawns no process, which matters for slow-starting tools such as
    ``mfa``. It also means a PATH entry that exists but is not executable
    yields False rather than raising an uncaught ``PermissionError``.

    Args:
        command: Executable name to look up on PATH, or a path to one.

    Returns:
        True if the command can be located and executed, False otherwise.
    """
    return shutil.which(command) is not None
def download_ljspeech(output_dir: str = ".") -> str:
    """Download and extract the LJSpeech-1.1 dataset.

    If the extracted dataset directory already exists, the user is asked
    whether to re-download. Any answer other than ``y`` keeps the existing
    copy. After a successful extraction the ``.tar.bz2`` archive is deleted.

    Args:
        output_dir: Directory to download the archive to and extract into.

    Returns:
        Path (as a string) of the extracted ``LJSpeech-1.1`` directory.

    Exits:
        Calls ``sys.exit(1)`` if neither wget nor curl is available, or if
        the download or the extraction fails.
    """
    output_path = Path(output_dir)
    archive_path = output_path / LJSPEECH_ARCHIVE
    dataset_path = output_path / LJSPEECH_DIR

    # Check if already downloaded
    if dataset_path.exists():
        logger.info(f"LJSpeech dataset already exists at: {dataset_path}")
        response = input("Re-download? (y/N): ").strip().lower()
        if response != 'y':
            logger.info("Skipping download")
            return str(dataset_path)

    # Download
    url = LJSPEECH_URL
    size = "2.6 GB"
    logger.info("Downloading LJSpeech dataset")
    logger.info(f"Downloading from: {url}")
    logger.info(f"This is a {size} download - it may take a while...")

    if check_command_exists("wget"):
        cmd = ["wget", "-O", str(archive_path), url]
    elif check_command_exists("curl"):
        # --fail: exit non-zero on HTTP errors instead of saving the server's
        # error page under the archive name.
        cmd = ["curl", "-L", "--fail", "-o", str(archive_path), url]
    else:
        logger.error("Neither wget nor curl found. Please install one of them.")
        logger.info("Or download manually from: " + url)
        sys.exit(1)

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Download failed: {e}")
        # Don't leave a truncated/bogus archive behind for the next run.
        if archive_path.exists():
            archive_path.unlink()
        sys.exit(1)
    logger.info("Download complete")

    # Extract
    logger.info("Extracting archive...")
    try:
        subprocess.run(
            ["tar", "-xjf", str(archive_path), "-C", str(output_path)],
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the `tar` executable itself is missing.
        logger.error(f"Extraction failed: {e}")
        sys.exit(1)
    logger.info("Extraction complete")

    # Remove archive to save space
    logger.info("Removing archive to save space...")
    archive_path.unlink()

    logger.info(f"LJSpeech dataset ready at: {dataset_path}")
    return str(dataset_path)
def download_zenodo_alignments(dataset_path: str) -> str:
    """Download pre-computed MFA TextGrid alignments for LJSpeech from Zenodo.

    The ``grids.zip`` archive is saved next to the dataset directory,
    extracted into ``<dataset>/TextGrid`` and then deleted. If the archive
    stores its files under a top-level ``grids/`` folder, the ``*.TextGrid``
    files are moved up into ``TextGrid/`` directly.

    If ``TextGrid/`` already exists the user is asked whether to re-download.
    Any answer other than ``y`` keeps the existing files.

    Args:
        dataset_path: Path to the extracted LJSpeech directory.

    Returns:
        The ``TextGrid`` directory as a string.

    Exits:
        Calls ``sys.exit(1)`` in any of these cases:

        * neither wget nor curl is available;
        * the download fails;
        * the downloaded file is not a valid zip archive.
    """
    dataset_path = Path(dataset_path)
    output_path = dataset_path.parent
    archive_path = output_path / ALIGNMENTS_ARCHIVE
    textgrid_path = dataset_path / "TextGrid"

    # Check if alignments already exist
    if textgrid_path.exists():
        logger.info(f"TextGrid alignments already exist at: {textgrid_path}")
        response = input("Re-download? (y/N): ").strip().lower()
        if response != 'y':
            logger.info("Skipping alignment download")
            return str(textgrid_path)

    # Download alignments
    url = LJSPEECH_ZENODO_ALIGNMENTS_URL
    logger.info("Downloading pre-computed MFA alignments from Zenodo")
    logger.info(f"Downloading from: {url}")

    if check_command_exists("wget"):
        cmd = ["wget", "-O", str(archive_path), url]
    elif check_command_exists("curl"):
        # --fail: exit non-zero on HTTP errors instead of saving the error page.
        cmd = ["curl", "-L", "--fail", "-o", str(archive_path), url]
    else:
        logger.error("Neither wget nor curl found. Please install one of them.")
        sys.exit(1)

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Alignment download failed: {e}")
        # Don't leave a truncated/bogus archive behind for the next run.
        if archive_path.exists():
            archive_path.unlink()
        sys.exit(1)
    logger.info("Download complete")

    # Extract to TextGrid directory
    logger.info("Extracting alignments...")
    textgrid_path.mkdir(parents=True, exist_ok=True)
    try:
        # Stdlib zipfile: no external `unzip` binary required, and it
        # overwrites existing files without prompting.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(textgrid_path)
    except zipfile.BadZipFile as e:
        logger.error(f"Extraction failed: {e}")
        logger.info(f"{archive_path} is not a valid zip archive - delete it and retry")
        sys.exit(1)

    # Flatten a top-level grids/ folder into TextGrid/ if the archive has one.
    grids_path = textgrid_path / "grids"
    if grids_path.is_dir():
        logger.info("Reorganizing extracted files...")
        for textgrid_file in grids_path.glob("*.TextGrid"):
            # replace() overwrites an existing target on every platform;
            # rename() raises on Windows when the target exists.
            textgrid_file.replace(textgrid_path / textgrid_file.name)
        try:
            grids_path.rmdir()
        except OSError:
            # Directory still holds non-TextGrid entries; leave them in place.
            logger.warning(f"Could not remove non-empty directory: {grids_path}")
    logger.info("Extraction complete")

    # Remove archive to save space
    logger.info("Removing archive to save space...")
    archive_path.unlink()

    logger.info(f"Pre-computed alignments ready at: {textgrid_path}")
    return str(textgrid_path)
def setup_mfa():
    """Report whether Montreal Forced Aligner (MFA) is available.

    If the ``mfa`` executable is found, the output of ``mfa version`` is
    logged. Otherwise step-by-step installation instructions are logged.

    Returns:
        True when ``mfa`` was found, False otherwise.
    """
    logger.info("\nChecking for Montreal Forced Aligner (MFA)...")

    if not check_command_exists("mfa"):
        logger.warning("Montreal Forced Aligner (MFA) not found")
        install_help = (
            "\nMFA is required for generating phoneme duration alignments.",
            "Without MFA, the model will use uniform duration fallback (poor quality).",
            "\nTo install MFA:",
            " 1. Install conda if not already installed:",
            " https://docs.conda.io/en/latest/miniconda.html",
            "\n 2. Install MFA via conda:",
            " conda install -c conda-forge montreal-forced-aligner",
            "\n 3. Re-run this script with --align flag",
        )
        for message in install_help:
            logger.info(message)
        return False

    logger.info("MFA is installed!")
    version_proc = subprocess.run(["mfa", "version"], capture_output=True, text=True)
    logger.info(f"MFA version: {version_proc.stdout.strip()}")
    return True
def _make_text_normalizer():
    """Build the transcript normalizer applied before forced alignment.

    Digits are spelled out with ``inflect``. If ``inflect`` cannot be
    imported, two warnings are logged and the returned callable returns its
    input unchanged.

    Returns:
        A ``str -> str`` callable.
    """
    import re

    try:
        import inflect
        engine = inflect.engine()
    except ImportError:
        logger.warning("inflect not installed - numbers will be kept as-is (may cause mismatches)")
        logger.warning("Install with: pip install inflect")
        return lambda text: text

    # Integers and simple decimals such as "1929" or "3.5".
    number_pattern = re.compile(r'\b\d+\.?\d*\b')

    def convert_number(match):
        num_str = match.group(0)
        try:
            if '.' in num_str:
                return engine.number_to_words(float(num_str))
            num = int(num_str)
            if 1000 <= num <= 2099:
                # Year-like values: inflect's group=2 reads the number in
                # two-digit groups (1929 -> "nineteen" ... "twenty-nine").
                return engine.number_to_words(num, group=2)
            return engine.number_to_words(num)
        except Exception:
            # Leave anything inflect cannot convert untouched rather than
            # aborting the whole metadata pass.
            return num_str

    def normalize_text(text):
        return number_pattern.sub(convert_number, text)

    return normalize_text


def _write_transcripts(metadata_file, corpus_path, normalize_text):
    """Write a normalized ``<file_id>.txt`` transcript next to each .wav.

    Each ``metadata.csv`` row is split on ``|``. The first field is the file
    id and the second is the transcription. Rows with fewer than two fields
    are skipped. Existing .txt files are overwritten so the current
    normalization is always applied.

    Returns:
        ``(created, normalized)``: the number of files written, and how many
        of those had their text changed by *normalize_text*.
    """
    created = 0
    normalized = 0
    with open(metadata_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('|')
            if len(parts) < 2:
                continue
            file_id, transcription = parts[0], parts[1]
            text = normalize_text(transcription)
            if text != transcription:
                normalized += 1
            txt_path = corpus_path / f"{file_id}.txt"
            with open(txt_path, 'w', encoding='utf-8') as txt_f:
                txt_f.write(text)
            created += 1
    return created, normalized


def run_mfa_alignment(dataset_path: str, use_custom_dict: bool = False) -> str:
    """Run Montreal Forced Aligner (MFA) over the LJSpeech corpus.

    Steps:
      1. Write a normalized ``<id>.txt`` beside every ``wavs/<id>.wav``,
         using the rows of ``metadata.csv``.
      2. Download the ``english_us_arpa`` dictionary and acoustic model.
      3. Run ``mfa align``, writing TextGrids to ``<dataset>/TextGrid/wavs``.

    If that output directory already exists the user is asked whether to
    re-run. Any answer other than ``y`` returns immediately.

    Args:
        dataset_path: Path to the extracted LJSpeech directory.
        use_custom_dict: Legacy flag. It only triggers a warning; the standard
            ``english_us_arpa`` dictionary is always used.

    Returns:
        The TextGrid output directory as a string.

    Exits:
        Calls ``sys.exit(1)`` in any of these cases:

        * ``mfa`` is not installed;
        * ``metadata.csv`` or ``wavs/`` is missing;
        * any MFA command fails.
    """
    logger.info("\nRunning Montreal Forced Aligner...")
    dataset_path = Path(dataset_path)

    # MFA needs to point to the directory containing .wav and .txt files
    corpus_path = dataset_path / "wavs"
    output_path = dataset_path / "TextGrid" / "wavs"

    if output_path.exists():
        logger.info(f"Alignments already exist at: {output_path}")
        response = input("Re-run alignment? (y/N): ").strip().lower()
        if response != 'y':
            logger.info("Skipping alignment")
            return str(output_path)

    if not check_command_exists("mfa"):
        logger.error("MFA is not installed. Run setup without --align first.")
        sys.exit(1)

    logger.info("This process will take 1-3 hours depending on your hardware...")
    (dataset_path / "TextGrid").mkdir(parents=True, exist_ok=True)

    # Step 1: per-utterance .txt files (MFA requires them alongside the .wav)
    logger.info("\n" + "=" * 70)
    logger.info("Step 1/3: Creating transcription files for MFA")
    logger.info("=" * 70)

    metadata_file = dataset_path / "metadata.csv"
    if not metadata_file.exists():
        logger.error(f"metadata.csv not found at {metadata_file}")
        sys.exit(1)
    if not corpus_path.is_dir():
        logger.error(f"Audio directory not found at {corpus_path}")
        sys.exit(1)

    logger.info(f"Reading metadata from {metadata_file}...")
    normalize_text = _make_text_normalizer()
    txt_created, txt_normalized = _write_transcripts(metadata_file, corpus_path, normalize_text)
    logger.info(f"✓ Created {txt_created} transcription .txt files")
    logger.info(f"✓ Normalized {txt_normalized} files (numbers → words)")
    logger.info("✓ MFA and g2p_en will now process the same text")

    if use_custom_dict:
        # Custom dictionary generation was removed; the flag only warns.
        logger.warning("\n⚠️ Custom dictionary generation is legacy functionality")
        logger.warning("Not needed when using g2p_en - it already matches english_us_arpa!")
        logger.info("\nFalling back to standard english_us_arpa dictionary...")

    dictionary = "english_us_arpa"
    acoustic_model = "english_us_arpa"

    try:
        # Step 2: dictionary and acoustic model
        logger.info("\n" + "=" * 70)
        logger.info("Step 2/3: Downloading MFA models")
        logger.info("=" * 70)
        logger.info("Downloading english_us_arpa dictionary...")
        subprocess.run(
            ["mfa", "model", "download", "dictionary", dictionary],
            check=True,
        )
        logger.info("Downloading english_us_arpa acoustic model...")
        subprocess.run(
            ["mfa", "model", "download", "acoustic", acoustic_model],
            check=True,
        )

        # Step 3: forced alignment
        logger.info("\n" + "=" * 70)
        logger.info("Step 3/3: Running forced alignment")
        logger.info("=" * 70)
        logger.info(f"Input corpus: {corpus_path}")
        logger.info(f"Output: {output_path}")
        logger.info("Using standard english_us_arpa dictionary")
        logger.info("✓ This matches g2p_en phoneme output perfectly!")

        # mfa align <corpus_dir> <dictionary> <acoustic_model> <output_dir>
        subprocess.run(
            [
                "mfa", "align",
                str(corpus_path),   # wavs/ with paired .wav and .txt files
                dictionary,
                acoustic_model,
                str(output_path),   # TextGrid/wavs
                "--clean",          # discard state from previous runs
                "--verbose",
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"MFA alignment failed: {e}")
        logger.info("\nTroubleshooting:")
        logger.info(" - Make sure conda is activated")
        logger.info(" - Try running MFA commands manually to see detailed errors")
        logger.info(" - Check MFA documentation: https://montreal-forced-aligner.readthedocs.io/")
        sys.exit(1)

    logger.info(f"\n✓ Alignment complete! TextGrid files saved to: {output_path}")
    num_textgrids = sum(1 for _ in output_path.glob("*.TextGrid"))
    logger.info(f"✓ Created {num_textgrids} TextGrid files")
    return str(output_path)
def verify_installation(dataset_path: str) -> bool:
    """Check that the dataset on disk looks usable and log a summary.

    ``metadata.csv`` and a ``wavs/`` directory are required. TextGrid
    alignments are optional. They are counted recursively so that both
    layouts are recognized:

    * the flat Zenodo layout, ``TextGrid/*.TextGrid``;
    * the local MFA output layout, ``TextGrid/wavs/*.TextGrid``.

    Args:
        dataset_path: Path to the extracted LJSpeech directory.

    Returns:
        True if ``metadata.csv`` and ``wavs/`` exist, False otherwise.
    """
    logger.info("\nVerifying installation...")
    dataset_path = Path(dataset_path)

    # Check metadata
    metadata_file = dataset_path / "metadata.csv"
    if not metadata_file.exists():
        logger.error(f"Metadata file not found: {metadata_file}")
        return False
    with open(metadata_file, 'r', encoding='utf-8') as f:
        num_samples = sum(1 for _ in f)
    logger.info(f"✓ Metadata file: {num_samples} samples")

    # Check wavs
    wavs_dir = dataset_path / "wavs"
    if not wavs_dir.exists():
        logger.error("✗ Audio directory not found")
        return False
    num_wavs = sum(1 for _ in wavs_dir.glob("*.wav"))
    logger.info(f"✓ Audio files: {num_wavs} WAV files")

    # Check alignments (recursive: local MFA writes into TextGrid/wavs/)
    textgrid_dir = dataset_path / "TextGrid"
    has_alignments = textgrid_dir.exists()
    num_textgrids = 0
    if has_alignments:
        num_textgrids = sum(1 for _ in textgrid_dir.rglob("*.TextGrid"))
        logger.info(f"✓ MFA alignments: {num_textgrids} TextGrid files")
    else:
        logger.warning("✗ No MFA alignments found (will use uniform fallback)")

    logger.info("\nDataset structure:")
    logger.info(f" {dataset_path}/")
    logger.info(f" metadata.csv ({num_samples} entries)")
    logger.info(f" wavs/ ({num_wavs} files)")
    if has_alignments:
        logger.info(f" TextGrid/ ({num_textgrids} files)")
    logger.info("\n✓ Dataset is ready for training!")
    return True
def main():
    """Command-line entry point: download LJSpeech and optionally align it.

    Flow:
      1. Parse the command-line options.
      2. Download or locate the dataset.
      3. Optionally fetch Zenodo alignments (``--zenodo``).
      4. Optionally run MFA locally (``--align`` / ``--align-only``).
      5. Verify the result and print next steps.

    Exits with status 1 in either case:

    * a required step fails;
    * the final verification does not find ``metadata.csv`` and ``wavs/``.
    """
    parser = argparse.ArgumentParser(
        description="Setup LJSpeech dataset for Kokoro English TTS training"
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='.',
        help='Directory to download dataset to (default: current directory)'
    )
    parser.add_argument(
        '--zenodo',
        action='store_true',
        help='Download pre-computed MFA alignments from Zenodo (faster than running MFA locally)'
    )
    parser.add_argument(
        '--align',
        action='store_true',
        help='Run Montreal Forced Aligner after download (not needed if using --zenodo)'
    )
    parser.add_argument(
        '--skip-download',
        action='store_true',
        help='Skip download (use existing dataset)'
    )
    parser.add_argument(
        '--align-only',
        action='store_true',
        help='Only run alignment (assumes dataset already downloaded)'
    )
    parser.add_argument(
        '--no-custom-dict',
        action='store_true',
        help='Use standard MFA dictionary (english_us_arpa) - RECOMMENDED since we use g2p_en which matches MFA'
    )
    args = parser.parse_args()

    print("\n" + "=" * 70)
    print("LJSpeech Dataset Setup")
    print("=" * 70 + "\n")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Check for conflicting flags
    if args.zenodo and args.align:
        logger.warning("Using both --zenodo and --align is redundant")
        logger.info("--zenodo will download pre-computed alignments, --align will run MFA locally")
        logger.info("You typically only need one of these options")

    # Download dataset (or locate an existing one)
    if not args.skip_download and not args.align_only:
        dataset_path = download_ljspeech(str(output_dir))
    else:
        dataset_path = str(output_dir / LJSPEECH_DIR)
        if not Path(dataset_path).exists():
            logger.error(f"Dataset not found at: {dataset_path}")
            logger.info("Run without --skip-download to download it")
            sys.exit(1)

    # Download pre-computed alignments from Zenodo if requested
    if args.zenodo and not args.align_only:
        logger.info("\n" + "=" * 70)
        logger.info("Downloading pre-computed alignments from Zenodo")
        logger.info("=" * 70 + "\n")
        download_zenodo_alignments(dataset_path)
        # MFA was not checked on this path; it is checked below only if the
        # user still asks for local alignment.
        mfa_installed = False
    else:
        # Setup/check MFA for local alignment
        mfa_installed = setup_mfa()

    # Run alignment locally if requested
    if args.align or args.align_only:
        use_custom_dict = not args.no_custom_dict
        run_local = True
        if args.zenodo and not args.align_only:
            logger.warning("Zenodo alignments already downloaded - local MFA not needed")
            response = input("Run MFA alignment anyway? (y/N): ").strip().lower()
            if response != 'y':
                logger.info("Skipping local MFA alignment")
                run_local = False
            else:
                # Actually check for MFA instead of using the placeholder False.
                mfa_installed = setup_mfa()
        if run_local:
            if not mfa_installed:
                logger.error("Cannot run alignment - MFA is not installed")
                sys.exit(1)
            run_mfa_alignment(dataset_path, use_custom_dict=use_custom_dict)

    # Verify; don't advertise a broken dataset as ready
    if not verify_installation(dataset_path):
        logger.error("Dataset verification failed - see messages above")
        sys.exit(1)

    # Next steps
    print("\n" + "=" * 70)
    print("Next Steps")
    print("=" * 70)
    if not Path(dataset_path).joinpath("TextGrid").exists():
        print("\n⚠️ No MFA alignments found!")
        print("\nFor better quality, get pre-computed alignments:")
        print(" python setup_ljspeech.py --zenodo --skip-download")
        print("\nOr run MFA alignment locally (takes 1-3 hours):")
        print(" python setup_ljspeech.py --align-only")
        print("\nOr train without alignments (lower quality):")
        print(f" python training_english.py --corpus {dataset_path}")
    else:
        print("\n✓ Dataset is ready with MFA alignments!")
        print("\nStart training:")
        print(f" python training_english.py --corpus {dataset_path}")
        print("\nTest mode (quick test with small subset):")
        print(f" python training_english.py --corpus {dataset_path} --test-mode")
    print("\n" + "=" * 70 + "\n")


if __name__ == "__main__":
    main()