diff --git a/app.py b/app.py index 4752e66e..cac00d8b 100644 --- a/app.py +++ b/app.py @@ -331,7 +331,7 @@ def open_folder(folder_path: str): parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value, +parser.add_argument('--whisper_type', type=str, default=WhisperImpl.WHISPER.value, choices=[item.value for item in WhisperImpl], help='A type of the whisper implementation (Github repo name)') parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') diff --git a/backend/common/audio.py b/backend/common/audio.py index a84c66e5..9b37319f 100644 --- a/backend/common/audio.py +++ b/backend/common/audio.py @@ -1,7 +1,7 @@ from io import BytesIO import numpy as np import httpx -import faster_whisper +from whisper import audio as whisper_audio from pydantic import BaseModel from fastapi import ( HTTPException, @@ -31,6 +31,6 @@ async def read_audio( raise HTTPException(status_code=422, detail="Could not download the file") file_content = file_response.content file_bytes = BytesIO(file_content) - audio = faster_whisper.audio.decode_audio(file_bytes) + audio = whisper_audio.decode_audio(file_bytes) duration = len(audio) / 16000 return audio, AudioInfo(duration=duration) diff --git a/backend/routers/transcription/router.py b/backend/routers/transcription/router.py index 11cad9d1..7f4ee70e 100644 --- a/backend/routers/transcription/router.py +++ b/backend/routers/transcription/router.py @@ -12,7 +12,7 @@ from datetime import datetime from modules.whisper.data_classes import * from modules.utils.paths import BACKEND_CACHE_DIR -from modules.whisper.faster_whisper_inference import FasterWhisperInference +from modules.whisper.whisper_Inference import WhisperInference from backend.common.audio import read_audio from backend.common.models import QueueResponse from backend.common.config_loader import load_server_config @@ -41,9 +41,9 @@ def progress_callback(progress_value: float): @functools.lru_cache -def get_pipeline() -> 'FasterWhisperInference': +def get_pipeline() -> 'WhisperInference': config = load_server_config()["whisper"] - inferencer = FasterWhisperInference( + inferencer = WhisperInference( output_dir=BACKEND_CACHE_DIR ) inferencer.update_model( diff --git a/modules/utils/audio_manager.py b/modules/utils/audio_manager.py index d0e0b998..70f0677c 100644 --- a/modules/utils/audio_manager.py +++ b/modules/utils/audio_manager.py @@ -2,7 +2,7 @@ import soundfile as sf import os import numpy as np -from faster_whisper.audio import decode_audio +from whisper.audio import decode_audio from modules.utils.files_manager import is_video from modules.utils.logger import get_logger diff --git a/modules/utils/logger.py b/modules/utils/logger.py index 589fdc66..30772e6c 100644 --- a/modules/utils/logger.py +++ b/modules/utils/logger.py @@ -14,9 +14,12 @@ def get_logger(name: Optional[str] = None): "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - handler = logging.StreamHandler() - # handler.setFormatter(formatter) + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) - logger.addHandler(handler) + file_handler = logging.FileHandler("webui.log") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) return logger \ No newline at end of file diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 259e4e33..3aeb694d 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -117,6 +117,8 @@ def run(self, """ start_time = time.time() + logger.info(f"Run started. audio={audio}") + if not validate_audio(audio): return [Segment()], 0 @@ -125,6 +127,7 @@ def run(self, bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization if bgm_params.is_separate_bgm: + logger.info("Starting background music separation") music, audio, _ = self.music_separator.separate( audio=audio, model_name=bgm_params.uvr_model_size, @@ -145,10 +148,12 @@ def run(self, if bgm_params.enable_offload: self.music_separator.offload() elapsed_time_bgm_sep = time.time() - start_time + logger.info("BGM separation completed") origin_audio = deepcopy(audio) if vad_params.vad_filter: + logger.info("Applying VAD") progress(0, desc="Filtering silent parts from audio..") vad_options = VadOptions( threshold=vad_params.threshold, @@ -168,7 +173,9 @@ def run(self, audio = vad_processed else: vad_params.vad_filter = False + logger.info("VAD completed") + logger.info("Starting transcription") result, elapsed_time_transcription = self.transcribe( audio, progress, @@ -177,6 +184,7 @@ def run(self, ) if whisper_params.enable_offload: self.offload() + logger.info("Transcription completed") if vad_params.vad_filter: restored_result = self.vad.restore_speech_timestamps( @@ -189,6 +197,7 @@ def run(self, logger.info("VAD detected no speech segments in the audio.") if diarization_params.is_diarize: + logger.info("Running diarization") progress(0.99, desc="Diarizing speakers..") result, elapsed_time_diarization = self.diarizer.run( audio=origin_audio, @@ -198,6 +207,7 @@ def run(self, ) if diarization_params.enable_offload: self.diarizer.offload() + logger.info("Diarization completed") self.cache_parameters( params=params, @@ -261,6 +271,8 @@ def transcribe_file(self, "highlight_words": True if params.whisper.word_timestamps else False } + logger.info(f"Transcribing files: {files} input_folder={input_folder_path}") + if input_folder_path: files = get_media_files(input_folder_path, include_sub_directory=include_subdirectory) if isinstance(files, str): @@ -270,6 +282,7 @@ def transcribe_file(self, files_info = {} for file in files: + logger.info(f"Processing file {file}") transcribed_segments, time_for_task = self.run( file, progress, @@ -312,6 +325,7 @@ def transcribe_file(self, result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}" result_file_path = [info['path'] for info in files_info.values()] + logger.info("File transcription completed") return result_str, result_file_path except Exception as e: diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index dd9f0d96..97112a59 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -368,7 +368,7 @@ def to_gradio_inputs(cls, available_langs: Optional[List] = None, available_compute_types: Optional[List] = None, compute_type: Optional[str] = None): - whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower() + whisper_type = WhisperImpl.WHISPER.value if whisper_type is None else whisper_type.strip().lower() inputs = [] if not only_advanced: diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index 181222bd..d73ea70d 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -64,12 +64,14 @@ def transcribe(self, elapsed_time: float elapsed time for transcription """ + logger.info("Transcribe called") start_time = time.time() params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) + logger.info(f"Model set to {params.model_size} ({params.compute_type})") segments, info = self.model.transcribe( audio=audio, @@ -112,6 +114,7 @@ def transcribe(self, segments_result.append(Segment.from_faster_whisper(segment)) elapsed_time = time.time() - start_time + logger.info(f"Transcription done in {elapsed_time:.2f}s") return segments_result, elapsed_time def update_model(self, @@ -133,6 +136,7 @@ def update_model(self, progress: gr.Progress Indicator to show progress directly in gradio. """ + logger.info(f"Loading model {model_size} with compute_type {compute_type}") progress(0, desc="Initializing Model..") model_size_dirname = model_size.replace("/", "--") if "/" in model_size else model_size @@ -163,6 +167,7 @@ def update_model(self, compute_type=self.current_compute_type, local_files_only=local_files_only ) + logger.info("Model initialized") def get_model_paths(self): """ diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 1e9c85da..b4710799 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -91,8 +91,8 @@ def create_whisper_inference( uvr_model_dir=uvr_model_dir ) else: - return FasterWhisperInference( - model_dir=faster_whisper_model_dir, + return WhisperInference( + model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index b8e3a871..7f5b528c 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -18,7 +18,6 @@ "whisper_type,vad_filter,bgm_separation,diarization", [ (WhisperImpl.WHISPER.value, False, True, False), - (WhisperImpl.FASTER_WHISPER.value, False, True, False), (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False) ] ) @@ -39,7 +38,6 @@ def test_bgm_separation_pipeline( "whisper_type,vad_filter,bgm_separation,diarization", [ (WhisperImpl.WHISPER.value, True, True, False), - (WhisperImpl.FASTER_WHISPER.value, True, True, False), (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False) ] ) diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 0e647fa3..46d01cd0 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -17,7 +17,6 @@ "whisper_type,vad_filter,bgm_separation,diarization", [ (WhisperImpl.WHISPER.value, False, False, True), - (WhisperImpl.FASTER_WHISPER.value, False, False, True), (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True) ] ) diff --git a/tests/test_transcription.py b/tests/test_transcription.py index f285a5a5..16f46e5c 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -88,7 +88,6 @@ def run_asr_pipeline( "whisper_type,vad_filter,bgm_separation,diarization", [ (WhisperImpl.WHISPER.value, False, False, False), - (WhisperImpl.FASTER_WHISPER.value, False, False, False), (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False) ] ) diff --git a/tests/test_vad.py b/tests/test_vad.py index d90a4e23..db069167 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -13,7 +13,6 @@ "whisper_type,vad_filter,bgm_separation,diarization", [ (WhisperImpl.WHISPER.value, True, False, False), - (WhisperImpl.FASTER_WHISPER.value, True, False, False), (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False) ] )