Skip to content

Commit d242e6a

Browse files
committed
Enables async podcast generation with task tracking and UI stop feature.
1 parent e87ef69 commit d242e6a

File tree

3 files changed

+420
-97
lines changed

3 files changed

+420
-97
lines changed

app.py

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from flask import Flask, render_template, request, jsonify, send_from_directory
2-
from generate_podcast import generate, PODCAST_SCRIPT, setup_logging, validate_speakers, update_elevenlabs_quota
2+
from generate_podcast import generate, DEFAULT_INSTRUCTION, DEFAULT_SCRIPT, setup_logging, validate_speakers, update_elevenlabs_quota
33
from utils import sanitize_text, get_asset_path, get_app_data_dir
44
from config import AVAILABLE_VOICES, DEFAULT_APP_SETTINGS, DEMO_AVAILABLE
55
from create_demo import create_html_demo_whisperx
@@ -11,11 +11,16 @@
1111
import shutil
1212
from elevenlabs.core import ApiError
1313
import re
14+
import uuid
15+
import threading
1416

1517
# --- App Initialization ---
1618
app = Flask(__name__)
1719
logger = setup_logging()
1820

21+
# --- In-Memory Task Manager ---
22+
tasks = {}
23+
1924
# --- Version & License ---
2025
try:
2126
from _version import __version__
@@ -55,7 +60,10 @@ def save_settings(settings):
5560
# --- Routes ---
5661
@app.route('/')
5762
def index():
58-
return render_template('index.html', default_script=PODCAST_SCRIPT, demo_available=DEMO_AVAILABLE)
63+
return render_template('index.html',
64+
default_instruction=DEFAULT_INSTRUCTION,
65+
default_script=DEFAULT_SCRIPT,
66+
demo_available=DEMO_AVAILABLE)
5967

6068
@app.route('/assets/<path:filename>')
6169
def get_asset(filename):
@@ -130,6 +138,38 @@ def get_gemini_sample(voice_name):
130138
return "Sample directory not found", 404
131139
return send_from_directory(sample_path, f"{voice_name}.mp3")
132140

141+
def run_generation_task(task_id, script_text, app_settings, output_filepath, api_key):
142+
"""The target function for the generation thread."""
143+
stop_event = tasks[task_id]['stop_event']
144+
try:
145+
generated_file = generate(
146+
script_text=script_text,
147+
app_settings=app_settings,
148+
output_filepath=output_filepath,
149+
api_key=api_key,
150+
status_callback=logger.info,
151+
stop_event=stop_event
152+
)
153+
if generated_file:
154+
tasks[task_id]['status'] = 'completed'
155+
tasks[task_id]['result'] = {'download_url': f'/temp/{os.path.basename(generated_file)}', 'filename': os.path.basename(generated_file)}
156+
except Exception as e:
157+
# If the exception is due to the stop event, set a specific status
158+
if "stopped by user" in str(e):
159+
tasks[task_id]['status'] = 'cancelled'
160+
tasks[task_id]['error'] = 'Generation cancelled by user.'
161+
# Clean up the partially created file
162+
if os.path.exists(output_filepath):
163+
try:
164+
os.remove(output_filepath)
165+
logger.info(f"Removed partial file for stopped task: {output_filepath}")
166+
except OSError as err:
167+
logger.error(f"Error removing partial file for stopped task: {err}")
168+
else:
169+
logger.error(f"Error during generation for task {task_id}: {e}", exc_info=True)
170+
tasks[task_id]['status'] = 'failed'
171+
tasks[task_id]['error'] = str(e)
172+
133173
@app.route('/generate', methods=['POST'])
134174
def handle_generate():
135175
script_text = request.form.get('script', '')
@@ -146,46 +186,53 @@ def handle_generate():
146186
except ValueError as e:
147187
return jsonify({'error': str(e)}), 400
148188

149-
first_words = re.sub(r'<[^>]+>', '', sanitized_script).strip().split()[:2]
150-
base_name = "_".join(first_words).lower()
151-
safe_base_name = re.sub(r'[^a-z0-9_]+', '', base_name)
152-
if not safe_base_name:
153-
safe_base_name = "podcast"
154-
random_suffix = os.urandom(4).hex()
155-
output_filename = f"{safe_base_name}_{random_suffix}.mp3"
156-
157-
output_filepath = os.path.join(app.config['TEMP_DIR'], output_filename)
158-
159189
provider = app_settings.get("tts_provider", "elevenlabs")
160190
api_key_env_var = "ELEVENLABS_API_KEY" if provider == "elevenlabs" else "GEMINI_API_KEY"
161191
api_key = os.environ.get(api_key_env_var)
162-
163192
if not api_key:
164193
return jsonify({'error': f'API key ({api_key_env_var}) not found in environment variables.'}), 500
165194

166195
from utils import sanitize_app_settings_for_backend
167196
app_settings_clean = sanitize_app_settings_for_backend(app_settings)
168197

169-
try:
170-
generated_file = generate(
171-
script_text=sanitized_script,
172-
app_settings=app_settings_clean,
173-
output_filepath=output_filepath,
174-
api_key=api_key,
175-
status_callback=logger.info
176-
)
177-
if generated_file:
178-
return jsonify({'download_url': f'/temp/{output_filename}', 'filename': output_filename})
179-
else:
180-
return jsonify({'error': 'Generation failed for an unknown reason. Check server logs.'}), 500
181-
except ApiError as e:
182-
error_detail = e.body.get('detail', {})
183-
message = error_detail.get('message', 'An unknown ElevenLabs API error occurred.')
184-
logger.error(f"ElevenLabs API Error: {message}")
185-
return jsonify({'error': f"ElevenLabs Error: {message}"}), 500
186-
except Exception as e:
187-
logger.error(f"Error during generation: {e}", exc_info=True)
188-
return jsonify({'error': f'An unexpected error occurred: {str(e)}'}), 500
198+
task_id = str(uuid.uuid4())
199+
output_filename = f"{task_id}.mp3"
200+
output_filepath = os.path.join(app.config['TEMP_DIR'], output_filename)
201+
202+
stop_event = threading.Event()
203+
thread = threading.Thread(target=run_generation_task, args=(task_id, sanitized_script, app_settings_clean, output_filepath, api_key))
204+
205+
tasks[task_id] = {'thread': thread, 'stop_event': stop_event, 'status': 'running'}
206+
thread.start()
207+
208+
return jsonify({'task_id': task_id})
209+
210+
@app.route('/api/generation_status/<task_id>', methods=['GET'])
211+
def get_generation_status(task_id):
212+
task = tasks.get(task_id)
213+
if not task:
214+
return jsonify({'error': 'Task not found'}), 404
215+
216+
response = {'status': task['status']}
217+
if task['status'] == 'completed':
218+
response['result'] = task['result']
219+
elif task['status'] in ['failed', 'cancelled']:
220+
response['error'] = task.get('error', 'An unknown error occurred.')
221+
222+
return jsonify(response)
223+
224+
@app.route('/api/stop_generation/<task_id>', methods=['POST'])
225+
def stop_generation(task_id):
226+
task = tasks.get(task_id)
227+
if not task:
228+
return jsonify({'error': 'Task not found'}), 404
229+
230+
if task['status'] == 'running':
231+
task['stop_event'].set()
232+
task['status'] = 'stopping'
233+
return jsonify({'status': 'Stop signal sent.'})
234+
235+
return jsonify({'status': 'Task was not running.'})
189236

190237
@app.route('/api/generate_demo', methods=['POST'])
191238
def handle_generate_demo():

generate_podcast.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import getpass
1313
from typing import Optional, Any, Dict, List, Tuple
1414
import tempfile
15+
import threading
1516

1617
import json
1718
import keyring # For secure credential storage
@@ -23,11 +24,12 @@
2324
# Global logger instance - initialized once when module is imported
2425
logger = logging.getLogger(__name__)
2526

26-
# The podcast script is now a constant to be used by the console mode.
27-
PODCAST_SCRIPT = """Read aloud in a warm, welcoming tone
28-
John: [playful] Who am I? I am a little old lady. My hair is white. I have got a small crown and a black handbag. My dress is blue. My country's flag is red, white and blue. I am on many coins and stamps. I love dogs, my dogs' names are corgis! Who am I??
27+
# The podcast script is now split into instruction and main script
28+
DEFAULT_INSTRUCTION = "Read aloud in a warm, welcoming tone"
29+
DEFAULT_SCRIPT = """John: [playful] Who am I? I am a little old lady. My hair is white. I have got a small crown and a black handbag. My dress is blue. My country's flag is red, white and blue. I am on many coins and stamps. I love dogs, my dogs' names are corgis! Who am I??
2930
Samantha: [laughing] You're queen Elizabeth II!!
3031
"""
32+
PODCAST_SCRIPT = f"{DEFAULT_INSTRUCTION}\n{DEFAULT_SCRIPT}"
3133

3234

3335
def setup_logging() -> logging.Logger:
@@ -127,15 +129,15 @@ def get_api_key(status_callback, logger: logging.Logger, parent_window=None, ser
127129

128130

129131
class TTSProvider:
130-
def synthesize(self, script_text: str, speaker_mapping: dict, output_filepath: str, status_callback=print) -> str:
132+
def synthesize(self, script_text: str, speaker_mapping: dict, output_filepath: str, status_callback=print, stop_event: Optional[threading.Event] = None) -> str:
131133
raise NotImplementedError
132134

133135

134136
class GeminiTTS(TTSProvider):
135137
def __init__(self, api_key: str):
136138
self.api_key = api_key
137139

138-
def synthesize(self, script_text: str, speaker_mapping: dict, output_filepath: str, status_callback=print) -> str:
140+
def synthesize(self, script_text: str, speaker_mapping: dict, output_filepath: str, status_callback=print, stop_event: Optional[threading.Event] = None) -> str:
139141
logger = logging.getLogger("PodcastGenerator")
140142
client = genai.Client(api_key=self.api_key)
141143

@@ -161,11 +163,15 @@ def synthesize(self, script_text: str, speaker_mapping: dict, output_filepath: s
161163
generate_content_config = types.GenerateContentConfig(temperature=1, response_modalities=["audio"], speech_config=speech_config)
162164

163165
for i, model_name in enumerate(models_to_try):
166+
if stop_event and stop_event.is_set():
167+
raise Exception("Generation stopped by user.")
164168
status_callback(f"\nAttempting generation with model: {model_name}...")
165169
try:
166170
audio_chunks = []
167171
final_mime_type = ""
168172
for chunk in client.models.generate_content_stream(model=model_name, contents=contents, config=generate_content_config):
173+
if stop_event and stop_event.is_set():
174+
raise Exception("Generation stopped by user during streaming.")
169175
if not (chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts):
170176
continue
171177
part = chunk.candidates[0].content.parts[0]
@@ -194,7 +200,7 @@ def __init__(self, api_key: str):
194200
self.client = ElevenLabs(api_key=api_key)
195201
self.logger = logging.getLogger("PodcastGenerator")
196202

197-
def synthesize(self, script_text: str, speaker_mapping: Dict[str, str], output_filepath: str, status_callback=print) -> str:
203+
def synthesize(self, script_text: str, speaker_mapping: Dict[str, str], output_filepath: str, status_callback=print, stop_event: Optional[threading.Event] = None) -> str:
198204
segments = self._parse_script_segments(script_text)
199205
if not segments:
200206
raise ValueError("No valid dialogue segments found in the script. Ensure lines are in 'Speaker: Text' format.")
@@ -220,6 +226,8 @@ def synthesize(self, script_text: str, speaker_mapping: Dict[str, str], output_f
220226

221227
with open(output_filepath, "wb") as f:
222228
for chunk in audio_generator:
229+
if stop_event and stop_event.is_set():
230+
raise Exception("Generation stopped by user during streaming.")
223231
f.write(chunk)
224232

225233
status_callback(f"File saved successfully: {output_filepath}")
@@ -228,8 +236,12 @@ def synthesize(self, script_text: str, speaker_mapping: Dict[str, str], output_f
228236
self.logger.error(f"ElevenLabs API error: {e}")
229237
raise e
230238
except Exception as e:
231-
self.logger.error(f"ElevenLabs critical error: {e}", exc_info=True)
232-
raise Exception(f"An unexpected critical error occurred in ElevenLabs TTS: {e}")
239+
# Don't re-raise the "stopped by user" exception, just let it be handled in the main generate function
240+
if "stopped by user" not in str(e):
241+
self.logger.error(f"ElevenLabs critical error: {e}", exc_info=True)
242+
raise Exception(f"An unexpected critical error occurred in ElevenLabs TTS: {e}")
243+
# Re-raise the stop exception to be caught by the task runner
244+
raise e
233245

234246
def _parse_script_segments(self, script_text: str) -> List[Tuple[str, str]]:
235247
segments = []
@@ -306,11 +318,14 @@ def validate_speakers(script_text: str, app_settings: Dict[str, Any]) -> Tuple[L
306318
return missing_speakers, configured_speakers
307319

308320

309-
def generate(script_text: str, app_settings: dict, output_filepath: str, status_callback=print, api_key: Optional[str] = None, parent_window=None) -> str:
321+
def generate(script_text: str, app_settings: dict, output_filepath: str, status_callback=print, api_key: Optional[str] = None, parent_window=None, stop_event: Optional[threading.Event] = None) -> str:
310322
logger = logging.getLogger("PodcastGenerator")
311323
logger.info("Starting generation function.")
312324
status_callback("Starting podcast generation...")
313325

326+
if stop_event and stop_event.is_set():
327+
raise Exception("Generation stopped by user before starting.")
328+
314329
sanitized_script_text = sanitize_text(script_text)
315330
if not find_ffmpeg_path():
316331
raise FileNotFoundError("FFmpeg executable not found.")
@@ -331,7 +346,7 @@ def generate(script_text: str, app_settings: dict, output_filepath: str, status_
331346
ProviderClass = ElevenLabsTTS if provider_name == "elevenlabs" else GeminiTTS
332347
provider = ProviderClass(api_key=api_key)
333348

334-
return provider.synthesize(script_text=sanitized_script_text, speaker_mapping=speaker_mapping, output_filepath=output_filepath, status_callback=status_callback)
349+
return provider.synthesize(script_text=sanitized_script_text, speaker_mapping=speaker_mapping, output_filepath=output_filepath, status_callback=status_callback, stop_event=stop_event)
335350

336351

337352
def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:

0 commit comments

Comments
 (0)