forked from LoredCast/filewizard
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·3756 lines (3247 loc) · 166 KB
/
main.py
File metadata and controls
executable file
·3756 lines (3247 loc) · 166 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# main.py (merged)
import html
import threading
import logging
import shutil
import subprocess
import traceback
import uuid
import shlex
import yaml
import os
import httpx
import glob
import cv2
import numpy as np
import secrets
import hashlib
from contextlib import asynccontextmanager
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional
import resource
from threading import Semaphore
from logging.handlers import RotatingFileHandler
from urllib.parse import urljoin, urlparse
from io import BytesIO
import zipfile
import sys
import re
import importlib
import collections.abc
import time
import ocrmypdf
import pypdf
import pytesseract
from fastapi.middleware.cors import CORSMiddleware
from pytesseract import TesseractNotFoundError
from PIL import Image, UnidentifiedImageError
from faster_whisper import WhisperModel
from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request,
UploadFile, status, Body, WebSocket, Query, WebSocketDisconnect)
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from huey import SqliteHuey, crontab
from pydantic import BaseModel, ConfigDict, field_serializer
from sqlalchemy import (Column, DateTime, Integer, String, Text,
create_engine, delete, event, text)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.exc import OperationalError
from string import Formatter
from werkzeug.utils import secure_filename
from typing import List as TypingList
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth
from dotenv import load_dotenv
from piper import PiperVoice
import wave
import io
import mimetypes
ENABLE_WEBSOCKETS = False
load_dotenv()
# --- Optional Dependency Handling for Piper TTS ---
try:
from piper.synthesis import SynthesisConfig
# download helpers: some piper versions export download_voice, others expose ensure_voice_exists/find_voice
try:
# prefer the more explicit helpers if present
from piper.download import get_voices, ensure_voice_exists, find_voice, VoiceNotFoundError
except Exception:
# fall back to older API if available
try:
from piper.download import get_voices, download_voice, VoiceNotFoundError
ensure_voice_exists = None
find_voice = None
except Exception:
# partial import failed -> treat as piper-not-installed for download helpers
get_voices = None
download_voice = None
ensure_voice_exists = None
find_voice = None
VoiceNotFoundError = None
except ImportError:
SynthesisConfig = None
get_voices = None
download_voice = None
ensure_voice_exists = None
find_voice = None
VoiceNotFoundError = None
try:
from PyPDF2 import PdfMerger
_HAS_PYPDF2 = True
except Exception:
_HAS_PYPDF2 = False
# Instantiate OAuth object (was referenced in code)
oauth = OAuth()
# --------------------------------------------------------------------------------
# --- 1. CONFIGURATION & SECURITY HELPERS
# --------------------------------------------------------------------------------
# --- Path Safety ---
# Base directories every user-supplied path must resolve into. They are
# resolved once at import time so later is_relative_to() checks compare
# absolute, symlink-free paths. Defaults target the container layout (/app).
UPLOADS_BASE = Path(os.environ.get("UPLOADS_DIR", "/app/uploads")).resolve()
PROCESSED_BASE = Path(os.environ.get("PROCESSED_DIR", "/app/processed")).resolve()
# Chunked-upload scratch space; defaults to a tmp/ subfolder of the uploads dir.
CHUNK_TMP_BASE = Path(os.environ.get("CHUNK_TMP_DIR", str(UPLOADS_BASE / "tmp"))).resolve()
def ensure_path_is_safe(p: Path, allowed_bases: List[Path]) -> Path:
    """Resolve *p* and verify it lies inside one of *allowed_bases*.

    Returns the fully resolved absolute path on success.
    Raises ValueError when the path escapes every allowed base or when it
    cannot be resolved at all.

    BUG FIX: the original wrapped its own "outside of allowed directories"
    ValueError in a broad except clause that replaced it with a generic
    message, discarding the useful diagnostic. Resolution failures are now
    caught narrowly and the containment error propagates intact (callers
    still catch ValueError either way).
    """
    try:
        resolved_p = p.resolve()
    except OSError as e:
        # resolve() can fail on e.g. symlink loops or invalid components.
        logging.getLogger(__name__).error(f"Path safety check failed for {p}: {e}")
        raise ValueError("Invalid or unsafe path specified.") from e
    if not any(resolved_p.is_relative_to(base) for base in allowed_bases):
        logging.getLogger(__name__).error(f"Path {resolved_p} is outside of allowed directories.")
        raise ValueError(f"Path {resolved_p} is outside of allowed directories.")
    return resolved_p
def sanitize_filename(filename: str) -> str:
    """Sanitize a user-supplied filename for safe storage and HTML display.

    Strips path-traversal components via werkzeug's secure_filename (already
    imported at module level — the original re-imported it locally for no
    reason) and then HTML-escapes the result so it can be echoed into pages.
    None/empty input yields "".
    """
    safe_name = secure_filename(filename or "")
    return html.escape(safe_name)
def sanitize_output(output: str) -> str:
    """Cap tool output at 2000 characters and HTML-escape it (XSS guard)."""
    if not output:
        return ""
    truncated = output[:2000]
    return html.escape(truncated)
def validate_file_type(filename: str, allowed_extensions: set) -> bool:
    """True when *filename*'s extension is permitted; empty set allows all."""
    if not allowed_extensions:
        return True
    ext = Path(filename).suffix.lower()
    return ext in allowed_extensions
def get_file_mime_type(filename: str) -> str:
    """Guess the MIME type from the extension; binary fallback when unknown."""
    guessed, _encoding = mimetypes.guess_type(filename)
    if guessed is None:
        return "application/octet-stream"  # Default to binary if unknown
    return guessed
def get_file_extension(filename: str) -> str:
    """Return the lower-cased extension of *filename*, including the dot."""
    suffix = Path(filename).suffix
    return suffix.lower()
def get_supported_output_formats_for_file(filename: str, conversion_tools_config: dict) -> list:
    """
    Collect every output format whose tool accepts *filename*'s extension.

    Each match is a dict with 'value' (tool_format key), 'label'
    (human-readable "Tool - Format"), 'tool' and 'format' keys, mirroring
    the supported_input/formats declarations in the tools configuration.
    """
    file_ext = Path(filename).suffix.lower()
    matches = []
    for tool_name, tool_config in conversion_tools_config.items():
        # Case-insensitive comparison against the tool's declared inputs.
        accepted = {ext.lower() for ext in tool_config.get("supported_input", [])}
        if file_ext not in accepted:
            continue
        for format_key, format_label in tool_config.get("formats", {}).items():
            matches.append({
                "value": f"{tool_name}_{format_key}",
                "label": f"{tool_config['name']} - {format_label}",
                "tool": tool_name,
                "format": format_key,
            })
    return matches
# --- Resource Limiting ---
def _limit_resources_preexec():
    """preexec_fn for subprocesses: cap CPU time and memory to curb DoS.

    Limits: 6000 s of CPU and a 4 GB address space. Best-effort — setrlimit
    is unavailable on some platforms (Windows, certain containers), in which
    case the failure is logged and ignored.
    """
    four_gib = 4 * 1024 * 1024 * 1024
    try:
        resource.setrlimit(resource.RLIMIT_CPU, (6000, 6000))
        resource.setrlimit(resource.RLIMIT_AS, (four_gib, four_gib))
    except Exception as exc:
        logging.getLogger(__name__).warning(f"Could not set resource limits: {exc}")
# --- Model concurrency semaphore (lazily initialized) ---
_model_semaphore: Optional[Semaphore] = None
def get_model_semaphore() -> Semaphore:
    """Return the process-wide model-concurrency semaphore, creating it lazily.

    The limit comes from app_settings.model_concurrency when present,
    otherwise the MODEL_CONCURRENCY env var, otherwise 1.
    """
    global _model_semaphore
    if _model_semaphore is None:
        env_limit = int(os.environ.get("MODEL_CONCURRENCY", "1"))
        limit = APP_CONFIG.get("app_settings", {}).get("model_concurrency", env_limit)
        _model_semaphore = Semaphore(limit)
        logger.info(f"Model concurrency semaphore initialized with limit: {limit}")
    return _model_semaphore
# --- Logging Setup ---
# Console logging via basicConfig, plus a rotating file log (app.log,
# 10 MB cap, one backup) attached to the root logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
_log_handler = RotatingFileHandler("app.log", maxBytes=10*1024*1024, backupCount=1)
_log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s')
_log_handler.setFormatter(_log_formatter)
logging.getLogger().addHandler(_log_handler)
logger = logging.getLogger(__name__)
# --- Environment Mode ---
# LOCAL_ONLY defaults to enabled: with it set, authentication is skipped.
LOCAL_ONLY_MODE = os.getenv('LOCAL_ONLY', 'True').lower() in ('true', '1', 't')
if LOCAL_ONLY_MODE:
    logger.warning("Authentication is DISABLED. Running in LOCAL_ONLY mode.")
class AppPaths(BaseModel):
    """Central registry of filesystem locations and database URLs.

    Field defaults are computed at class-definition time; BASE_DIR is the
    directory containing this file and anchors all derived paths.
    """
    BASE_DIR: Path = Path(__file__).resolve().parent
    UPLOADS_DIR: Path = UPLOADS_BASE
    PROCESSED_DIR: Path = PROCESSED_BASE
    CHUNK_TMP_DIR: Path = CHUNK_TMP_BASE
    # TTS voice model storage (Piper under tts/, Kokoro in its own subdir).
    TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts"
    KOKORO_TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts" / "kokoro"
    KOKORO_MODEL_FILE: Path = KOKORO_TTS_MODELS_DIR / "kokoro-v1.0.onnx"
    KOKORO_VOICES_FILE: Path = KOKORO_TTS_MODELS_DIR / "voices-v1.0.bin"
    # SQLite job store lives next to the code; Huey gets a separate DB file.
    DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}"
    HUEY_DB_PATH: str = str(BASE_DIR / "huey.db")
    CONFIG_DIR: Path = BASE_DIR / "config"
    SETTINGS_FILE: Path = CONFIG_DIR / "settings.yml"
    DEFAULT_SETTINGS_FILE: Path = BASE_DIR / "settings.default.yml"
PATHS = AppPaths()
# Filled in by load_app_config(); read throughout the app.
APP_CONFIG: Dict[str, Any] = {}
# Create every working directory up front so requests/tasks never race mkdir.
PATHS.UPLOADS_DIR.mkdir(exist_ok=True, parents=True)
PATHS.PROCESSED_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CHUNK_TMP_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CONFIG_DIR.mkdir(exist_ok=True, parents=True)
PATHS.TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True)
PATHS.KOKORO_TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True)
# --- WebSocket Connection Manager ---
# NOTE(review): several of these imports duplicate the top-of-file imports;
# kept as-is to preserve the original module layout.
import json
import asyncio
import threading
from typing import Dict, List
from collections import defaultdict
import time
class ConnectionManager:
    """Tracks per-user WebSocket connections and fans job updates out to them.

    Sync producers (e.g. Huey tasks) cannot await sends, so they persist
    updates to the Notification table via sync_broadcast_job_status_update();
    an async consumer later claims and delivers them through
    process_notification_queue().
    """
    def __init__(self):
        # Maps user_id to list of WebSocket connections
        self.active_connections: Dict[str, List[WebSocket]] = defaultdict(list)
        # Maps WebSocket to user_id
        self.connection_to_user: Dict[WebSocket, str] = {}
        # Maps WebSocket to connection metadata (currently just a connection_id)
        self.connection_metadata: Dict[WebSocket, dict] = {}
        # Logger
        self.logger = logging.getLogger(__name__)
    async def connect(self, websocket: WebSocket, user_id: str, connection_id: str = None):
        """Accept the socket and register it under *user_id*."""
        await websocket.accept()
        self.connection_to_user[websocket] = user_id
        self.connection_metadata[websocket] = {"connection_id": connection_id or str(uuid.uuid4())}
        self.active_connections[user_id].append(websocket)
    def disconnect(self, websocket: WebSocket):
        """Forget a socket; safe to call on sockets that were never registered."""
        user_id = self.connection_to_user.pop(websocket, None)
        if user_id and websocket in self.active_connections[user_id]:
            self.active_connections[user_id].remove(websocket)
        self.connection_metadata.pop(websocket, None)
    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* on a single socket."""
        await websocket.send_text(message)
    async def broadcast_user_jobs(self, user_id: str, message: str):
        """Send message to all connections for a specific user"""
        self.logger.debug(f"Broadcasting message to user {user_id}: {message}")
        if user_id in self.active_connections:
            disconnected = []
            sent_count = 0
            for websocket in self.active_connections[user_id]:
                try:
                    await websocket.send_text(message)
                    sent_count += 1
                except WebSocketDisconnect:
                    disconnected.append(websocket)
            # Remove disconnected connections
            for websocket in disconnected:
                self.disconnect(websocket)
            if sent_count > 0:
                self.logger.info(f"Sent WebSocket message to {sent_count} connections for user {user_id}")
        else:
            self.logger.info(f"No active connections for user {user_id}")
    async def broadcast_job_status_update(self, user_id: str, job_data: dict):
        """Send job status update to user's connections"""
        logger = logging.getLogger(__name__)
        logger.info(f"Broadcasting job update to user {user_id}: job_id={job_data.get('id')}, status={job_data.get('status')}")
        message = json.dumps({
            "type": "job_update",
            "job": job_data,
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        await self.broadcast_user_jobs(user_id, message)
        logger.info(f"Finished broadcasting job update to user {user_id}")
    async def broadcast_multiple_jobs_update(self, user_id: str, jobs_data: List):
        """Send multiple job updates to user's connections"""
        message = json.dumps({
            "type": "batch_job_update",
            "jobs": jobs_data,
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        await self.broadcast_user_jobs(user_id, message)
    def sync_broadcast_job_status_update(self, user_id: str, job_data: dict):
        """Synchronously broadcast job status update - for use from sync contexts like Huey tasks"""
        logger = logging.getLogger(__name__)
        job_id = job_data.get('id')
        status = job_data.get('status')
        progress = job_data.get('progress')
        logger.info(f"Queueing WebSocket notification for user {user_id}: job_id={job_id}, status={status}, progress={progress}")
        # BUG FIX: db must be bound before the try block. Previously
        # `db = SessionLocal()` lived inside the try, so a failure in
        # SessionLocal() made the finally clause raise UnboundLocalError.
        db = None
        try:
            db = SessionLocal()
            notification = Notification(
                user_id=user_id,
                job_data=json.dumps(job_data)
            )
            db.add(notification)
            db.commit()
            logger.info(f"Queued WebSocket notification for user {user_id}, job {job_id}")
        except Exception as e:
            logger.warning(f"Could not queue WebSocket notification for user {user_id}, job {job_id}: {e}")
        finally:
            if db is not None:
                db.close()
    async def process_notification_queue(self):
        """Process queued notifications and send them to WebSocket clients"""
        db = SessionLocal()
        claimed_notification_data = None
        try:
            # Find a notification to process (oldest first).
            notification_to_process = db.query(Notification).order_by(Notification.created_at).first()
            if not notification_to_process:
                return
            # Try to "claim" it by deleting it.
            notification_id = notification_to_process.id
            # We need to copy the data before deleting.
            claimed_notification_data = {
                "id": notification_id,
                "user_id": notification_to_process.user_id,
                "job_data": notification_to_process.job_data
            }
            deleted_count = db.query(Notification).filter_by(id=notification_id).delete(synchronize_session=False)
            db.commit()
            if deleted_count == 0:
                # Another worker got it first.
                self.logger.debug(f"Notification {notification_id} was already claimed by another worker.")
                claimed_notification_data = None  # Do not process
        except Exception as e:
            self.logger.error(f"Error claiming notification from DB: {e}")
            db.rollback()
            claimed_notification_data = None  # Do not process
        finally:
            db.close()
        # --- Process the claimed notification outside the DB transaction ---
        if claimed_notification_data:
            try:
                user_id = claimed_notification_data["user_id"]
                notification_id = claimed_notification_data["id"]
                self.logger.debug(f"Processing claimed notification {notification_id} for user {user_id}")
                # Only deliver when the user has live sockets; otherwise the
                # claimed (already deleted) notification is intentionally dropped.
                if user_id in self.active_connections:
                    job_data = json.loads(claimed_notification_data["job_data"])
                    message = json.dumps({
                        "type": "job_update",
                        "job": job_data,
                        "timestamp": datetime.now(timezone.utc).isoformat()
                    })
                    await self.broadcast_user_jobs(user_id, message)
                    self.logger.info(f"Sent claimed notification {notification_id} to user {user_id}")
            except Exception as e:
                self.logger.warning(f"Error sending claimed notification {claimed_notification_data['id']}: {e}")
# Initialize manager
# Single shared ConnectionManager instance used by HTTP handlers and tasks.
manager = ConnectionManager()
def deep_merge(source: dict, dest: dict) -> dict:
    """Recursively merge *source* into *dest* in place; source values win.

    Nested dicts are merged key-by-key; any other value type simply
    overwrites. Returns *dest* for call-chaining convenience.
    """
    for key, incoming in source.items():
        existing = dest.get(key)
        if isinstance(incoming, dict) and isinstance(existing, dict):
            deep_merge(incoming, existing)
        else:
            dest[key] = incoming
    return dest
def initialize_settings_file():
    """Create config/settings.yml from settings.default.yml when missing.

    If the defaults file itself is absent (or copying fails), an empty
    settings file is created so later reads do not crash.
    """
    if PATHS.SETTINGS_FILE.exists():
        return
    logger.info(f"'{PATHS.SETTINGS_FILE}' not found. Copying from '{PATHS.DEFAULT_SETTINGS_FILE}'.")
    try:
        shutil.copy(PATHS.DEFAULT_SETTINGS_FILE, PATHS.SETTINGS_FILE)
    except FileNotFoundError:
        logger.error(f"CRITICAL: Default settings file '{PATHS.DEFAULT_SETTINGS_FILE}' not found. Cannot initialize settings.")
        PATHS.SETTINGS_FILE.touch()
    except Exception as e:
        logger.error(f"CRITICAL: Failed to copy default settings file: {e}")
        PATHS.SETTINGS_FILE.touch()
def load_app_config() -> None:
    """
    Loads configuration by deeply merging settings from hardcoded defaults,
    settings.default.yml, and settings.yml, then applies environment variable
    overrides.

    Precedence (lowest to highest): hardcoded defaults < settings.default.yml
    < settings.yml < TRANSCRIPTION_* environment variables. The result is
    stored in the module-global APP_CONFIG. deep_merge(src, dst) lets *src*
    win, so each successive layer is merged as the source.
    """
    global APP_CONFIG
    # --- 1. Hardcoded Defaults ---
    hardcoded_defaults = {
        "app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": [], "app_public_url": ""},
        "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
        "tts_settings": {
            "piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}},
            "kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"}
        },
        "conversion_tools": {},
        "ocr_settings": {"ocrmypdf": {}},
        "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []},
        "webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""}
    }
    config = hardcoded_defaults
    # --- 2. Merge settings.default.yml ---
    try:
        with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f:
            default_cfg = yaml.safe_load(f) or {}
        config = deep_merge(default_cfg, config)
    except (FileNotFoundError, yaml.YAMLError) as e:
        logger.warning(f"Could not load or parse settings.default.yml: {e}. Using hardcoded defaults.")
    # --- 3. Merge settings.yml ---
    try:
        with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
            user_cfg = yaml.safe_load(f) or {}
        config = deep_merge(user_cfg, config)
    except (FileNotFoundError, yaml.YAMLError):
        # This is not an error, just means user is using defaults
        pass
    # --- 4. Environment Variable Overrides for Transcription ---
    # Safely access nested keys
    trans_settings = config.get("transcription_settings", {}).get("whisper", {})
    transcription_device = os.environ.get("TRANSCRIPTION_DEVICE", trans_settings.get("device", "cpu"))
    # CUDA devices default to float16; everything else to int8.
    default_compute_type = "float16" if transcription_device == "cuda" else "int8"
    transcription_compute_type = os.environ.get("TRANSCRIPTION_COMPUTE_TYPE", trans_settings.get("compute_type", default_compute_type))
    # Device index may be a single int or a comma-separated list (multi-GPU).
    transcription_device_index_str = os.environ.get("TRANSCRIPTION_DEVICE_INDEX", "0")
    try:
        if ',' in transcription_device_index_str:
            transcription_device_index = [int(i.strip()) for i in transcription_device_index_str.split(',')]
        else:
            transcription_device_index = int(transcription_device_index_str)
    except ValueError:
        logger.warning(f"Invalid TRANSCRIPTION_DEVICE_INDEX value: '{transcription_device_index_str}'. Defaulting to 0.")
        transcription_device_index = 0
    config.setdefault("transcription_settings", {}).setdefault("whisper", {})
    config["transcription_settings"]["whisper"]["device"] = transcription_device
    config["transcription_settings"]["whisper"]["compute_type"] = transcription_compute_type
    config["transcription_settings"]["whisper"]["device_index"] = transcription_device_index
    # --- 5. Final Processing & Assignment ---
    # Derive the byte limit and normalize the extension allow-list to a set.
    app_settings = config.get("app_settings", {})
    max_mb = app_settings.get("max_file_size_mb", 100)
    app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
    allowed = app_settings.get("allowed_all_extensions", [])
    if not isinstance(allowed, (list, set)):
        allowed = []
    app_settings["allowed_all_extensions"] = set(allowed)
    config["app_settings"] = app_settings
    APP_CONFIG = config
    logger.info("Application configuration loaded.")
# --------------------------------------------------------------------------------
# --- 2. DATABASE & Schemas
# --------------------------------------------------------------------------------
# NullPool + check_same_thread=False: every session gets a fresh SQLite
# connection, which is safe across Huey worker threads; the 30s busy timeout
# mitigates writer contention.
engine = create_engine(
    PATHS.DATABASE_URL,
    connect_args={"check_same_thread": False, "timeout": 30},
    poolclass=NullPool,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@event.listens_for(engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, connection_record):
    """On each new SQLite connection enable WAL and relax fsync to NORMAL."""
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL;")
        cursor.execute("PRAGMA synchronous=NORMAL;")
    finally:
        cursor.close()
class Job(Base):
    """ORM model for one processing job (upload, transcription, OCR, TTS...)."""
    __tablename__ = "jobs"
    id = Column(String, primary_key=True, index=True)
    user_id = Column(String, index=True, nullable=True)  # null in LOCAL_ONLY mode
    parent_job_id = Column(String, index=True, nullable=True)  # links batch children
    task_type = Column(String, index=True)
    status = Column(String, default="pending")
    progress = Column(Integer, default=0)  # 0-100
    original_filename = Column(String)
    input_filepath = Column(String)
    input_filesize = Column(Integer, nullable=True)  # bytes
    processed_filepath = Column(String, nullable=True)
    output_filesize = Column(Integer, nullable=True)  # bytes
    result_preview = Column(Text, nullable=True)  # first ~2000 chars of output
    error_message = Column(Text, nullable=True)
    callback_url = Column(String, nullable=True)  # webhook target, if any
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
def get_db():
    """FastAPI dependency: yield a DB session and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
class JobCreate(BaseModel):
    """Input payload for create_job(); field names mirror the Job ORM columns."""
    id: str
    user_id: str | None = None
    parent_job_id: str | None = None
    task_type: str
    original_filename: str
    input_filepath: str
    input_filesize: int | None = None
    callback_url: str | None = None
    processed_filepath: str | None = None
class JobSchema(BaseModel):
    """Read model for a Job row, built from ORM attributes (from_attributes)."""
    id: str
    parent_job_id: str | None = None
    task_type: str
    status: str
    progress: int
    original_filename: str
    input_filesize: int | None = None
    output_filesize: int | None = None
    processed_filepath: str | None = None
    result_preview: str | None = None
    error_message: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = ConfigDict(from_attributes=True)
    @field_serializer('created_at', 'updated_at')
    def serialize_dt(self, dt: datetime, _info):
        # Emit ISO-8601 with a trailing "Z". NOTE(review): assumes *dt* is a
        # naive UTC datetime (as SQLite typically returns); an aware value
        # would serialize as "...+00:00Z" — confirm rows are stored naive.
        return dt.isoformat() + "Z"
class FinalizeUploadPayload(BaseModel):
    """Request body that finalizes a chunked upload and enqueues processing."""
    upload_id: str
    original_filename: str
    total_chunks: int  # expected chunk count for reassembly
    task_type: str
    model_size: str = ""  # whisper model, when transcribing
    model_name: str = ""  # TTS voice, when synthesizing
    output_format: str = ""
    generate_timestamps: bool = False
    callback_url: Optional[str] = None # For API chunked uploads
class JobSelection(BaseModel):
    """Bulk-action request body: the IDs of the jobs to operate on."""
    job_ids: List[str]
class Notification(Base):
    """Durable queue row carrying a job update from sync producers (Huey
    tasks) to the async WebSocket broadcaster; deleted once claimed."""
    __tablename__ = "notifications"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(String, index=True, nullable=False)
    job_data = Column(Text, nullable=False)  # JSON-serialized JobSchema dump
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
# --------------------------------------------------------------------------------
# --- 3. CRUD OPERATIONS & WEBHOOKS
# --------------------------------------------------------------------------------
def get_job(db: Session, job_id: str):
    """Fetch a single Job row by primary key; returns None when absent.

    (Removed a commented-out duplicate of the query that was dead code.)
    """
    return db.query(Job).filter(Job.id == job_id).first()
def get_jobs(db: Session, user_id: str | None = None, skip: int = 0, limit: int = 100):
    """Return up to *limit* jobs, newest first, optionally scoped to one user."""
    q = db.query(Job)
    if user_id:
        q = q.filter(Job.user_id == user_id)
    q = q.order_by(Job.created_at.desc())
    return q.offset(skip).limit(limit).all()
def create_job(db: Session, job: JobCreate):
    """Insert a new Job row from *job* and return the refreshed ORM object.

    NOTE(review): an earlier revision validated a JobSchema here, apparently
    intending to broadcast the new job over WebSockets, but the value was
    never used and nothing was sent; the dead validation has been removed.
    """
    db_job = Job(**job.model_dump())
    db.add(db_job)
    db.commit()
    db.refresh(db_job)
    return db_job
def update_job_status(db: Session, job_id: str, status: str, progress: int = None, error: str = None):
    """Persist a status/progress/error change and notify WebSocket listeners.

    The broadcast fires only when status or progress actually changed and the
    job belongs to a user. Returns the (possibly None) Job row.
    """
    db_job = get_job(db, job_id)
    if not db_job:
        return db_job
    previous_status = db_job.status
    previous_progress = db_job.progress
    db_job.status = status
    if progress is not None:
        db_job.progress = progress
    if error:
        db_job.error_message = error
    db.commit()
    schema = JobSchema.model_validate(db_job)
    changed = (previous_status != status) or (progress is not None and previous_progress != progress)
    if changed and db_job.user_id:
        manager.sync_broadcast_job_status_update(db_job.user_id, schema.model_dump())
    return db_job
def mark_job_as_completed(db: Session, job_id: str, output_filepath_str: str | None = None, preview: str | None = None):
    """Record preview text / output size and flip the job to 'completed'.

    Cancelled (or missing) jobs are returned untouched. Stat failures on the
    output file are logged but never fail the completion.
    """
    db_job = get_job(db, job_id)
    if not db_job or db_job.status == 'cancelled':
        return db_job
    if preview:
        db_job.result_preview = preview.strip()[:2000]
    if output_filepath_str:
        try:
            out_path = Path(output_filepath_str)
            if out_path.exists():
                db_job.output_filesize = out_path.stat().st_size
        except Exception:
            logger.exception(f"Could not stat output file {output_filepath_str} for job {job_id}")
    update_job_status(db, job_id, "completed", progress=100)
    return db_job
def send_webhook_notification(job_id: str, app_config: Dict[str, Any], base_url: str):
    """Sends a notification to the callback URL if one is configured for the job.

    POSTs a JSON payload (status, filename, download URL, timestamps) to the
    job's callback_url, optionally carrying a bearer token. No-op when
    webhooks are disabled or the job has no callback URL; delivery failures
    are logged, never raised.
    """
    webhook_config = app_config.get("webhook_settings", {})
    if not webhook_config.get("enabled", False):
        return
    db = SessionLocal()
    try:
        job = get_job(db, job_id)
        if not job or not job.callback_url:
            return
        download_url = None
        if job.status == "completed" and job.processed_filepath:
            filename = Path(job.processed_filepath).name
            public_url = app_config.get("app_settings", {}).get("app_public_url", base_url)
            # BUG FIX: `filename` was computed but never interpolated into the
            # URLs, so every webhook pointed at a literal placeholder path.
            if not public_url:
                logger.warning(f"app_public_url is not set. Cannot generate a full download URL for job {job_id}.")
                download_url = f"/download/{filename}"  # Relative URL as fallback
            else:
                download_url = urljoin(public_url, f"/download/{filename}")
        payload = {
            "job_id": job.id,
            "status": job.status,
            "original_filename": job.original_filename,
            "download_url": download_url,
            "error_message": job.error_message,
            "created_at": job.created_at.isoformat() + "Z",
            "updated_at": job.updated_at.isoformat() + "Z",
        }
        headers = {"Content-Type": "application/json", "User-Agent": "FileProcessor-Webhook/1.0"}
        token = webhook_config.get("callback_bearer_token")
        if token:
            headers["Authorization"] = f"Bearer {token}"
        try:
            with httpx.Client() as client:
                response = client.post(job.callback_url, json=payload, headers=headers, timeout=15)
                response.raise_for_status()
                logger.info(f"Sent webhook notification for job {job_id} to {job.callback_url} (Status: {response.status_code})")
        except httpx.RequestError as e:
            logger.error(f"Failed to send webhook for job {job_id} to {job.callback_url}: {e}")
        except httpx.HTTPStatusError as e:
            logger.error(f"Webhook for job {job_id} received non-2xx response {e.response.status_code} from {job.callback_url}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred in send_webhook_notification for job {job_id}: {e}")
    finally:
        db.close()
# --------------------------------------------------------------------------------
# --- 4. BACKGROUND TASK SETUP
# --------------------------------------------------------------------------------
huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH)
# Process-wide caches of loaded ML models, guarded by the locks below.
WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {}
PIPER_VOICES_CACHE: Dict[str, "PiperVoice"] = {}
AVAILABLE_TTS_VOICES_CACHE: Dict[str, Any] | None = None
# Last-touch timestamps (time.time()) driving inactivity-based eviction.
WHISPER_MODELS_LAST_USED: Dict[str, float] = {}
# --- Cache Eviction Settings ---
_cache_cleanup_thread: Optional[threading.Thread] = None
_cache_lock = threading.Lock() # Global lock for modifying cache dictionaries
# One lock per model size so different models can load concurrently.
_model_locks: Dict[str, threading.Lock] = {}
_global_lock = threading.Lock() # Lock for initializing model-specific locks
def _whisper_cache_cleanup_worker():
    """
    Periodically checks for and unloads Whisper models that have been inactive.
    The timeout and check interval are configured in the application settings.

    DEADLOCK FIX: get_whisper_model() acquires the per-model lock first and
    _cache_lock second. The previous implementation here did the opposite
    (held _cache_lock while acquiring per-model locks) — a classic ABBA
    ordering that could deadlock against a concurrent model load. Expired
    models are now collected under _cache_lock alone, which is released
    before any per-model lock is taken, and each eviction re-checks the
    timestamp in case the model was touched in between.
    """
    while True:
        # Read settings within the loop to allow for live changes.
        app_settings = APP_CONFIG.get("app_settings", {})
        check_interval = app_settings.get("cache_check_interval", 300)
        inactivity_timeout = app_settings.get("model_inactivity_timeout", 1800)
        time.sleep(check_interval)
        now = time.time()
        # Phase 1: scan for expired models under the cache lock only.
        with _cache_lock:
            expired_models = [
                model_size
                for model_size, last_used in WHISPER_MODELS_LAST_USED.items()
                if now - last_used > inactivity_timeout
            ]
        if not expired_models:
            continue
        logger.info(f"Found {len(expired_models)} inactive Whisper models to unload: {expired_models}")
        # Phase 2: evict each model, taking locks in the same order as the
        # loader (model lock first, then cache lock).
        for model_size in expired_models:
            model_lock = _get_or_create_model_lock(model_size)
            with model_lock:
                with _cache_lock:
                    # Re-check: the model may have been touched or already
                    # unloaded since the expiry scan.
                    last_used = WHISPER_MODELS_LAST_USED.get(model_size)
                    if last_used is None or now - last_used <= inactivity_timeout:
                        continue
                    if model_size in WHISPER_MODELS_CACHE:
                        logger.info(f"Unloading inactive Whisper model: {model_size}")
                        model_to_unload = WHISPER_MODELS_CACHE.pop(model_size, None)
                        WHISPER_MODELS_LAST_USED.pop(model_size, None)
                        # Drop our reference to encourage garbage collection.
                        if model_to_unload:
                            del model_to_unload
        # Run garbage collection outside all locks.
        import gc
        gc.collect()
def get_whisper_model(model_size: str, whisper_settings: dict) -> Any:
    """
    Return a cached faster-whisper model, loading and caching it on first use.

    Args:
        model_size: Whisper model identifier (e.g. "base", "large-v3").
        whisper_settings: dict honouring "device" (default "cpu"),
            "compute_type" (default "int8") and "device_index" (default 0).

    Returns:
        The loaded ``WhisperModel`` instance.

    Raises:
        RuntimeError: if model construction fails (original exception chained).
    """
    # Fast path: cache hit — refresh the last-used timestamp and return.
    with _cache_lock:
        if model_size in WHISPER_MODELS_CACHE:
            logger.debug(f"Cache hit for model '{model_size}'")
            WHISPER_MODELS_LAST_USED[model_size] = time.time()
            return WHISPER_MODELS_CACHE[model_size]

    # Slow path: serialize loading per model so the same model is never
    # loaded twice concurrently. Lock order: model lock, then cache lock
    # (the cleanup worker must use the same order).
    model_lock = _get_or_create_model_lock(model_size)
    with model_lock:
        # Re-check cache inside the lock in case another thread loaded it.
        with _cache_lock:
            if model_size in WHISPER_MODELS_CACHE:
                WHISPER_MODELS_LAST_USED[model_size] = time.time()
                return WHISPER_MODELS_CACHE[model_size]
        logger.info(f"Loading Whisper model '{model_size}'...")
        try:
            device = whisper_settings.get("device", "cpu")
            compute_type = whisper_settings.get("compute_type", "int8")
            device_index = whisper_settings.get("device_index", 0)
            # os.cpu_count() may return None on some platforms; fall back to 1
            # so the integer division cannot raise TypeError.
            cpu_threads = max(1, (os.cpu_count() or 1) // 2)
            model = WhisperModel(
                model_size,
                device=device,
                device_index=device_index,
                compute_type=compute_type,
                cpu_threads=cpu_threads,
                num_workers=1,
            )
            # Publish the new model to the cache under the cache lock.
            with _cache_lock:
                WHISPER_MODELS_CACHE[model_size] = model
                WHISPER_MODELS_LAST_USED[model_size] = time.time()
            logger.info(f"Model '{model_size}' loaded (device={device}, compute={compute_type})")
            return model
        except Exception as e:
            logger.error(f"Model '{model_size}' failed to load: {str(e)}", exc_info=True)
            raise RuntimeError(f"Whisper model initialization failed: {e}") from e
def _get_or_create_model_lock(model_size: str) -> threading.Lock:
    """Return the lock dedicated to *model_size*, creating it on first request.

    The unguarded read is a fast path (dict reads are atomic in CPython);
    creation is made race-free by ``setdefault`` under ``_global_lock``.
    """
    existing = _model_locks.get(model_size)
    if existing is not None:
        return existing
    with _global_lock:
        return _model_locks.setdefault(model_size, threading.Lock())
def get_piper_voice(model_name: str, tts_settings: dict | None) -> "PiperVoice":
    """
    Load (or download + load) a Piper voice in a robust way:
    - Try Python API helpers (get_voices, ensure_voice_exists/find_voice, download_voice)
    - On any failure, try CLI fallback (download_voice_cli)
    - Attempt to locate model files after download (recursive subdirectory search)
    - Try re-importing piper if bindings were previously unavailable

    Args:
        model_name: Piper voice identifier (e.g. "en_US-lessac-medium").
        tts_settings: optional dict honouring "model_dir" (download/search
            directory, defaults to PATHS.TTS_MODELS_DIR) and "use_cuda"
            (bool, default False).

    Returns:
        A loaded (and process-cached) PiperVoice instance.

    Raises:
        RuntimeError: if Piper bindings are unavailable or voice files cannot
            be obtained; loader exceptions from PiperVoice.load() propagate.
    """
    # ----- Defensive normalization -----
    if tts_settings is None or not isinstance(tts_settings, dict):
        logger.debug("get_piper_voice: normalizing tts_settings (was %r)", tts_settings)
        tts_settings = {}
    model_dir_val = tts_settings.get("model_dir", None)
    if model_dir_val is None:
        model_dir = Path(str(PATHS.TTS_MODELS_DIR))
    else:
        try:
            model_dir = Path(model_dir_val)
        except Exception:
            logger.warning("Could not coerce tts_settings['model_dir']=%r to Path; using default.", model_dir_val)
            model_dir = Path(str(PATHS.TTS_MODELS_DIR))
    model_dir.mkdir(parents=True, exist_ok=True)

    # If PiperVoice already cached, reuse
    if model_name in PIPER_VOICES_CACHE:
        logger.info("Reusing cached Piper voice '%s'.", model_name)
        return PIPER_VOICES_CACHE[model_name]

    with get_model_semaphore():
        # Re-check after acquiring the semaphore: another worker may have
        # loaded this voice while we were waiting.
        if model_name in PIPER_VOICES_CACHE:
            return PIPER_VOICES_CACHE[model_name]

        # If Python bindings are missing, attempt CLI download first (and try re-import)
        if PiperVoice is None:
            logger.info("Piper Python bindings missing; attempting CLI download fallback for '%s' before failing import.", model_name)
            cli_ok = False
            try:
                cli_ok = download_voice_cli(model_name, model_dir)
            except Exception as e:
                logger.warning("CLI download attempt raised: %s", e)
                cli_ok = False
            if cli_ok:
                # attempt to re-import piper package (maybe import issue was transient)
                try:
                    importlib.invalidate_caches()
                    importlib.import_module("piper")
                    from piper import PiperVoice as _PiperVoice  # noqa: F401
                    from piper.synthesis import SynthesisConfig as _SynthesisConfig  # noqa: F401
                    # Patch the module-level names so the rest of this call
                    # (and future calls) see the freshly imported bindings.
                    globals().update({"PiperVoice": _PiperVoice, "SynthesisConfig": _SynthesisConfig})
                    logger.info("Successfully re-imported piper after CLI download.")
                except Exception:
                    logger.warning("Could not import piper after CLI download; bindings still unavailable.")
            # If bindings still absent, we cannot load models; raise helpful error
            if PiperVoice is None:
                raise RuntimeError(
                    "Piper Python bindings are not installed or failed to import. "
                    "Tried CLI download fallback but python bindings are still unavailable. "
                    "Please install 'piper-tts' in the runtime used by this process."
                )

        # Now we have Piper bindings (or they were present to begin with). Attempt Python helpers.
        onnx_path = None
        config_path = None

        # Prefer using get_voices to update the index if available
        voices_info = None
        try:
            if get_voices:
                try:
                    voices_info = get_voices(str(model_dir), update_voices=True)
                except TypeError:
                    # some versions may not support update_voices kwarg
                    voices_info = get_voices(str(model_dir))
        except Exception as e:
            logger.debug("get_voices failed or unavailable: %s", e)
            voices_info = None

        try:
            # Preferred modern helpers
            if ensure_voice_exists and find_voice:
                try:
                    ensure_voice_exists(model_name, [model_dir], model_dir, voices_info)
                    onnx_path, config_path = find_voice(model_name, [model_dir])
                except Exception as e:
                    # Could be VoiceNotFoundError or other download error
                    logger.warning("ensure/find voice failed for %s: %s", model_name, e)
                    raise
            elif download_voice:
                # older API: call download helper directly
                try:
                    download_voice(model_name, model_dir)
                    # attempt to locate files at the conventional paths
                    onnx_path = model_dir / f"{model_name}.onnx"
                    config_path = model_dir / f"{model_name}.onnx.json"
                except Exception:
                    logger.warning("download_voice failed for %s", model_name)
                    raise
            else:
                # No python download helper available
                raise RuntimeError("No Python download helper available in installed piper package.")
        except Exception as py_exc:
            # Python helper route failed; try CLI fallback BEFORE giving up
            logger.info("Python download route failed for '%s' (%s). Trying CLI fallback...", model_name, py_exc)
            try:
                cli_ok = download_voice_cli(model_name, model_dir)
            except Exception as e:
                logger.warning("CLI fallback attempt raised: %s", e)
                cli_ok = False
            if not cli_ok:
                # If CLI also failed, re-raise the original python exception to preserve context
                logger.error("Both Python download helpers and CLI fallback failed for '%s'.", model_name)
                raise
            # CLI succeeded (or at least returned success) — try to find files on disk
            onnx_path, config_path = _find_model_files(model_name, model_dir)
            if not (onnx_path and config_path):
                # The CLI may have written into a nested directory or used a
                # slightly different file name: search recursively instead of
                # repeating the same flat lookup (the original code called
                # _find_model_files twice with identical arguments here).
                logger.info("Could not find model files after CLI download in %s; attempting broader search...", model_dir)
                for candidate in sorted(model_dir.rglob(f"{model_name}*.onnx")):
                    candidate_cfg = candidate.parent / f"{candidate.name}.json"
                    if candidate_cfg.is_file():
                        onnx_path, config_path = candidate, candidate_cfg
                        break
            if not (onnx_path and config_path):
                logger.error("Model files still missing after CLI fallback for '%s'.", model_name)
                raise RuntimeError(f"Piper voice files for '{model_name}' missing after CLI fallback.")
            # continue to loading below

        # Final safety check and last-resort search
        if not (onnx_path and config_path):
            onnx_path, config_path = _find_model_files(model_name, model_dir)
        if not (onnx_path and config_path):
            raise RuntimeError(f"Piper voice files for '{model_name}' are missing after attempts to download.")

        # Load the PiperVoice
        try:
            use_cuda = bool(tts_settings.get("use_cuda", False))
            voice = PiperVoice.load(str(onnx_path), config_path=str(config_path), use_cuda=use_cuda)
            PIPER_VOICES_CACHE[model_name] = voice
            logger.info("Loaded Piper voice '%s' from %s", model_name, onnx_path)
            return voice
        except Exception as e:
            logger.exception("Failed to load Piper voice '%s' from files (%s, %s): %s", model_name, onnx_path, config_path, e)
            raise
def _find_model_files(model_name: str, model_dir: Path):
"""
Try multiple strategies to find onnx and config files for a given model_name under model_dir.
Returns (onnx_path, config_path) or (None, None).
"""
# direct files in model_dir
onnx = model_dir / f"{model_name}.onnx"