11import json
2+ import uuid
23from typing import Any
34
45import aiohttp
56from fastapi import APIRouter
67
8+ from transcribo_backend .models .progress import ProgressResponse
79from transcribo_backend .models .response_format import ResponseFormat
810from transcribo_backend .models .task_status import TaskStatus
911from transcribo_backend .models .transcription_response import TranscriptionResponse
1719BENTOML_API_URL = f"{ settings .whisper_api } /audio/transcriptions"
1820
1921
22+ taskId_to_progressId : dict [str , str ] = {}
23+
24+
2025async def transcribe_get_task_status (task_id : str ) -> TaskStatus :
2126 """
2227 Checks the status of an ongoing transcription task.
@@ -28,11 +33,19 @@ async def transcribe_get_task_status(task_id: str) -> TaskStatus:
2833 TaskStatus: The current status of the task
2934 """
3035 url = f"{ settings .whisper_api } /audio/transcriptions/task/status?task_id={ task_id } "
36+ progress_url = f"{ settings .whisper_api } /progress/{ taskId_to_progressId [task_id ]} "
3137
3238 # Get the status of the transcription task
33- async with aiohttp .ClientSession () as session , session .get (url ) as response :
39+ async with (
40+ aiohttp .ClientSession () as session ,
41+ session .get (url ) as response ,
42+ session .get (progress_url ) as progress_response ,
43+ ):
3444 response .raise_for_status ()
35- return TaskStatus (** await response .json ())
45+ progress_response .raise_for_status ()
46+
47+ progress = ProgressResponse (** await progress_response .json ())
48+ return TaskStatus (** await response .json (), progress = progress .progress )
3649
3750
3851async def transcribe_get_task_result (task_id : str ) -> TranscriptionResponse :
@@ -51,6 +64,15 @@ async def transcribe_get_task_result(task_id: str) -> TranscriptionResponse:
5164 async with aiohttp .ClientSession () as session , session .get (url ) as response :
5265 response .raise_for_status ()
5366 result_data = await response .json ()
67+
68+ taskId_to_progressId .pop (task_id , None )
69+
70+ transcription = TranscriptionResponse (** result_data )
71+ for segment in transcription .segments :
72+ segment .text = segment .text .strip ()
73+ segment .speaker = segment .speaker or "Unknown"
74+ segment .speaker = segment .speaker .strip ().capitalize ()
75+
5476 return TranscriptionResponse (** result_data )
5577
5678
@@ -135,6 +157,9 @@ async def transcribe_submit_task(
135157 form_data .add_field ("file" , audio_data , filename = "audio.wav" )
136158 form_data .add_field ("model" , model )
137159
160+ progress_id = uuid .uuid4 ().hex
161+ form_data .add_field ("progress_id" , progress_id )
162+
138163 if language :
139164 form_data .add_field ("language" , language )
140165 if prompt :
@@ -166,4 +191,6 @@ async def transcribe_submit_task(
166191 # Send the request
167192 async with aiohttp .ClientSession () as session , session .post (url , data = form_data ) as response :
168193 response .raise_for_status ()
169- return TaskStatus (** await response .json ())
194+ status = TaskStatus (** await response .json ())
195+ taskId_to_progressId [status .task_id ] = progress_id
196+ return status
0 commit comments