|
| 1 | +import random |
| 2 | + |
| 3 | +import nltk |
| 4 | +import parler_tts |
| 5 | +import torch |
| 6 | +import torchaudio |
| 7 | +import transformers |
| 8 | +from nltk.tokenize import sent_tokenize |
| 9 | + |
| 10 | +# Download NLTK's punkt tokenizer data if not already downloaded |
| 11 | +global _nltk_initialized |
| 12 | +_nltk_initialized = False |
| 13 | + |
# Function to split text into chunks using sentence tokenization
def nltk_chunk_text(text):
    """Split *text* into a list of sentences with NLTK's punkt tokenizer.

    The tokenizer data is fetched on the first call only; the module-level
    ``_nltk_initialized`` flag skips the download on subsequent calls.

    Args:
        text: The input string to segment.

    Returns:
        list[str]: The sentences of *text*, in order.
    """
    global _nltk_initialized
    if not _nltk_initialized:
        # NLTK >= 3.8.2 resolves sent_tokenize through the 'punkt_tab'
        # resource, while older versions use 'punkt'. Download both so the
        # function works across NLTK versions (downloads are no-ops when
        # the data is already present).
        nltk.download('punkt')
        nltk.download('punkt_tab')
        _nltk_initialized = True
    return sent_tokenize(text)
| 21 | + |
# Pool of speaker-description prompts for Parler-TTS. When text_to_speech()
# is called without an explicit prompt, one of these is chosen at random;
# the description conditions the generated voice (gender, pitch, pace,
# expressiveness, recording quality). Adjacent string literals are
# concatenated implicitly into a single prompt each.
_tts_speaker_prompts = [
    "A female speaker delivers an expressive and animated speech with a very high-pitch voice. "
    "The recording is slightly noisy but of good quality, as her voice comes across as very close-sounding.",
    "A female speaker delivers her speech with a slightly expressive and animated tone, "
    "her voice ringing clearly and undistorted in the recording. "
    "The pitch of her voice is very high, adding a sense of urgency and excitement.",
    "A female speaks with a slightly expressive and animated tone in a recording that sounds quite clear and close up. "
    "There is only a mild amount of background noise present, and her voice has a moderate pitch. "
    "Her speech pace is steady, neither slow nor particularly fast.",
    "A female speaker delivers her speech in a recording that sounds clear and close up. "
    "Her voice is slightly expressive and animated, with a moderate pitch. "
    "The recording has a mild amount of background noise, but her voice is still easily understood.",
    "In a somewhat confined space, a female speaker delivers a talk that is slightly expressive and animated, "
    "despite some background noise. "
    "Her voice has a low-pitch tone.",
    "A male voice speaks in a monotone tone with a slightly low-pitch, delivering his words at a moderate speed. "
    "The recording offers almost no noise, resulting in a very clear and high-quality listen. "
    "The close-up microphone captures every detail of his speech.",
    "A man speaks with a monotone tone and a slightly low-pitch, delivering his words at a moderate speed. "
    "The recording captures his speech very clearly and distinctly, with little to no background noise. "
    "The listener feels as if they're almost sharing the same space with the speaker.",
    "A male speaker delivers his words with a very monotone and slightly faster than average pace. "
    "His voice is very clear, making every word distinct, while it also has a slightly low-pitch tone. "
    "The recording quality is excellent, with no apparent reverberation or background noise.",
    "A male speaker delivers his words in a very monotone and slightly low-pitched voice, "
    "maintaining a moderate speed. The recording is of very high quality, with minimum noise "
    "and a very close-sounding reverberation that suggests a quiet and enclosed environment.",
]
| 50 | + |
| 51 | + |
| 52 | +global _tts_models |
| 53 | +_tts_models = {} |
| 54 | + |
| 55 | + |
def text_to_speech(
    text,
    prompt=None,
    device=None,
    model_name="parler-tts/parler-tts-mini-multilingual-v1.1",
    sampling_rate=16_000,
):
    """Synthesize *text* to speech with a Parler-TTS model.

    Args:
        text: The text to speak.
        prompt: Speaker-description prompt. A string is used as-is; a list
            has one entry chosen at random; ``None`` picks at random from
            the module-level ``_tts_speaker_prompts``.
        device: ``torch.device`` or device string; ``None`` auto-selects
            CUDA when available, else CPU.
        model_name: Hugging Face model name or local path.
        sampling_rate: Desired output sampling rate in Hz; the audio is
            resampled when the model's native rate differs.

    Returns:
        numpy.ndarray: 1-D float waveform at ``sampling_rate``.

    Raises:
        ValueError: If *prompt* is neither ``None``, a string, nor a list.
    """
    global _tts_models

    # Set up device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    if prompt is None:
        prompt = random.choice(_tts_speaker_prompts)
    elif isinstance(prompt, list):
        prompt = random.choice(prompt)
    elif not isinstance(prompt, str):
        raise ValueError("Prompt must be a string or a list of strings")

    # Load processor and model from Hugging Face, with caching in (V)RAM
    if model_name not in _tts_models:
        model = parler_tts.ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        # Speaker descriptions are tokenized with the text encoder's own
        # tokenizer, which may differ from the model's main tokenizer.
        description_tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
        _tts_models[model_name] = (model, tokenizer, description_tokenizer, model.config.sampling_rate)

    (model, tokenizer, description_tokenizer, model_sampling_rate) = _tts_models[model_name]
    # nn.Module.to() moves parameters in place, so the cached entry follows.
    model = model.to(device)

    text_tokens = tokenizer(text, return_tensors="pt").input_ids.to(device)
    speaker_type_prompt_tokens = description_tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Inference only: disable autograd so generation does not build a graph
    # (saves memory, and guarantees .numpy() below cannot fail on a tensor
    # that requires grad).
    with torch.inference_mode():
        audio_tensor = model.generate(input_ids=speaker_type_prompt_tokens, prompt_input_ids=text_tokens)

    # Squeeze a leading batch dimension of 1, if present.
    if audio_tensor.dim() == 2 and audio_tensor.shape[0] == 1:
        audio_tensor = audio_tensor[0]

    audio_tensor = audio_tensor.to("cpu")

    if sampling_rate != model_sampling_rate:
        audio_tensor = torchaudio.transforms.Resample(model_sampling_rate, sampling_rate)(audio_tensor)

    return audio_tensor.numpy()
| 108 | + |
| 109 | +if __name__ == "__main__": |
| 110 | + |
| 111 | + import argparse |
| 112 | + import os |
| 113 | + |
| 114 | + from audio import save_audio |
| 115 | + parser = argparse.ArgumentParser() |
| 116 | + parser.add_argument("words", type=str, nargs="+", help="Text to convert to speech") |
| 117 | + parser.add_argument("--device", type=str, default=None, help="Device to use for inference") |
| 118 | + parser.add_argument("--model_name", type=str, default="parler-tts/parler-tts-mini-multilingual-v1.1", |
| 119 | + help="Model name or path") |
| 120 | + parser.add_argument("--output", type=str, default="out", help="Output folder name") |
| 121 | + parser.add_argument("--num", type=int, default=10, help="Number of generations") |
| 122 | + args = parser.parse_args() |
| 123 | + |
| 124 | + text = " ".join(args.words) |
| 125 | + |
| 126 | + for i in range(args.num): |
| 127 | + prompt = random.choice(_tts_speaker_prompts) |
| 128 | + audio_tensor = text_to_speech(text, prompt, model_name=args.model_name, device=args.device) |
| 129 | + os.makedirs(args.output, exist_ok=True) |
| 130 | + with open(os.path.join(args.output, f"audio_{i:03d}_prompt.txt"), "w") as f: |
| 131 | + f.write(prompt) |
| 132 | + save_audio(os.path.join(args.output, f"audio_{i:03d}.wav"), audio_tensor) |
| 133 | + |
0 commit comments