
import asyncio
from collections.abc import AsyncGenerator
+ from pathlib import Path
from typing import Annotated
+ from uuid import UUID

- from fastapi import APIRouter, Body, Depends, HTTPException, Request, status
- from starlette.responses import StreamingResponse
+ import aiofiles
+ from fastapi import APIRouter, Body, Depends, HTTPException, status
+ from sse_starlette.sse import EventSourceResponse, ServerSentEvent

- from app.api.dependencies import get_job_queue, get_project_service
+ from app.api.dependencies import get_data_dir, get_job_dir, get_job_queue, get_project_service
from app.api.validators import JobID
from app.core.jobs.control_plane import CancellationResult, JobQueue
from app.core.jobs.models import JobStatus
async def submit_job(
    job_request: Annotated[JobRequest, Body()],
    job_queue: Annotated[JobQueue, Depends(get_job_queue)],
+     job_dir: Annotated[Path, Depends(get_job_dir)],
+     data_dir: Annotated[Path, Depends(get_data_dir)],
    project_service: Annotated[ProjectService, Depends(get_project_service)],
) -> JobView:
    """
@@ -41,6 +46,8 @@ async def submit_job(
    Args:
        job_request (JobRequest): The Job request payload.
        job_queue (JobQueue): The job queue instance responsible for managing job submissions and tracking job statuses.
+         job_dir (Path): The directory where job log files are stored.
+         data_dir (Path): The base directory for project data storage.
        project_service (ProjectService): The service to interact with project data.

    Returns:
@@ -53,6 +60,8 @@ async def submit_job(
        case JobType.TRAIN:
            job = TrainingJob(
                project_id=job_request.project_id,
+                 log_dir=job_dir,
+                 data_dir=data_dir,
                params=TrainingParams(
                    model_architecture_id=job_request.parameters.model_architecture_id,
                    parent_model_revision_id=job_request.parameters.parent_model_revision_id,
@@ -163,49 +172,106 @@ async def cancel_job(job_id: JobID, job_queue: Annotated[JobQueue, Depends(get_j

@router.get("/{job_id}/status")
async def stream_job_status(
-     job_id: JobID, request: Request, job_queue: Annotated[JobQueue, Depends(get_job_queue)]
- ) -> StreamingResponse:
+     job_id: JobID, job_queue: Annotated[JobQueue, Depends(get_job_queue)]
+ ) -> EventSourceResponse:
168177 """
169178 Stream real-time status updates for a specific job.
170179
171180 This endpoint streams job status updates using Server-Sent Events (SSE).
172- It sends periodic updates until the client disconnects or the job reaches
173- terminal state.
181+ It sends periodic updates until the job reaches terminal state.
174182
175183 Args:
176184 job_id (JobID): The unique identifier of the job.
177- request (Request): The HTTP request object to monitor client connection status.
178- job_queue (JobQueue): The job queue instance responsible for managing job submissions and tracking job statuses.
185+ job_queue (JobQueue): The job queue instance responsible for tracking job statuses.
179186
180187 Returns:
181- StreamingResponse : A streaming response with job status updates.
188+ EventSourceResponse : A streaming response with job status updates.
182189 """
    if not job_queue.get(job_id):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")

-     async def gen_job_updates() -> AsyncGenerator[str]:
-         """Generate job status updates."""
-         last = None
-         while True:
-             if await request.is_disconnected():
-                 break
-             j = job_queue.get(job_id)
-             if not j:
-                 break
-             snap = JobView.of(j).model_dump_json()
-             if snap != last:
-                 yield f"{snap}\n"
-                 last = snap
-             if j.status >= JobStatus.DONE:
-                 break
-             await asyncio.sleep(0.1)
- 
-     return StreamingResponse(
-         gen_job_updates(),
-         media_type="text/event-stream",
-         headers={
-             "Content-Type": "text/event-stream",
-             "Connection": "keep-alive",
-             "Cache-Control": "no-cache",
-         },
-     )
+     return EventSourceResponse(__gen_job_updates(job_id, job_queue))
+ 
+ 
+ @router.get("/{job_id}/logs")
+ async def stream_job_logs(
+     job_id: JobID,
+     job_dir: Annotated[Path, Depends(get_job_dir)],
+     job_queue: Annotated[JobQueue, Depends(get_job_queue)],
+ ) -> EventSourceResponse:
+     """
+     Stream real-time log output for a specific job.
+ 
+     This endpoint streams job logs using Server-Sent Events (SSE). It reads
+     the job's log file and yields new lines as they are written, allowing
+     clients to follow the job's progress in real time. The stream continues
+     until the job reaches a terminal state, the client disconnects, or an
+     error occurs.
+ 
+     Args:
+         job_id (JobID): The unique identifier of the job.
+         job_dir (Path): The directory where job log files are stored.
+         job_queue (JobQueue): The job queue instance for tracking job statuses.
+ 
+     Returns:
+         EventSourceResponse: A streaming response with log entries sent as SSE events.
+ 
+     Raises:
+         HTTPException: If the job is not found (404), the log file doesn't exist (404),
+             or the job has already completed (409).
+     """
+     job = job_queue.get(job_id)
+     if not job:
+         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
+ 
+     if job.status >= JobStatus.DONE:
+         raise HTTPException(
+             status_code=status.HTTP_409_CONFLICT,
+             detail="Job has already completed; logs are no longer available for streaming",
+         )
+ 
+     log_path = job_dir / job.log_file
+ 
+     if not log_path.exists():
+         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Log file not found")
+ 
+     return EventSourceResponse(__gen_log_stream(job_id, log_path, job_queue))
+ 
+ 
+ async def __gen_job_updates(job_id: UUID, job_queue: JobQueue) -> AsyncGenerator[ServerSentEvent]:
+     """Generate job status updates."""
+     last = None
+     while True:
+         j = job_queue.get(job_id)
+         if not j:
+             break
+         snap = JobView.of(j).model_dump_json()
+         if snap != last:
+             yield ServerSentEvent(data=snap)
+             last = snap
+         if j.status >= JobStatus.DONE:
+             break
+         await asyncio.sleep(0.1)
+ 
+ 
+ async def __gen_log_stream(job_id: UUID, log_path: Path, job_queue: JobQueue) -> AsyncGenerator[ServerSentEvent]:
+     """Asynchronously follow a log file and yield new lines as SSE events."""
+     try:
+         async with aiofiles.open(log_path) as f:
+             # First, emit everything the log file already contains.
+             async for line in f:
+                 yield ServerSentEvent(data=line.rstrip("\n"))
+ 
+             # Then poll for freshly appended lines until the job finishes.
+             while True:
+                 j = job_queue.get(job_id)
+                 line = await f.readline()
+                 if line:
+                     yield ServerSentEvent(data=line.rstrip("\n"))
+                     continue
+                 # No new data: stop once the job is gone or has reached a
+                 # terminal state, so the stream cannot poll forever on a
+                 # finished job.
+                 if not j or j.status >= JobStatus.DONE:
+                     break
+                 await asyncio.sleep(0.3)
+     except asyncio.CancelledError:
+         raise
+     except Exception as e:
+         yield ServerSentEvent(data=f"Error reading log file: {e}")
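
For a quick smoke test of the reworked /{job_id}/status endpoint, here is a minimal client-side sketch that follows the SSE stream with plain httpx. The base URL and the /jobs mount prefix are assumptions; adjust them to match how this router is actually mounted.

import asyncio

import httpx


async def follow_status(job_id: str) -> None:
    # Base URL and "/jobs" prefix are assumptions; adjust to your deployment.
    url = f"http://localhost:8000/jobs/{job_id}/status"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                # sse-starlette frames each event as "data: <payload>" lines
                # separated by blank lines; keep only the payload.
                if line.startswith("data:"):
                    print(line.removeprefix("data:").strip())


asyncio.run(follow_status("00000000-0000-0000-0000-000000000000"))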
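
The new stream_job_logs endpoint rests on the tail-follow pattern inside __gen_log_stream: drain what the file already holds, then poll readline() for appended lines. The standalone sketch below isolates that pattern with aiofiles; the "job.log" path and the fixed five-second cutoff (standing in for the job-completion check) are illustrative only.

import asyncio

import aiofiles


async def tail(path: str, stop_after: float = 5.0) -> None:
    # Illustrative time-based cutoff; the real generator stops when the
    # job reaches a terminal state instead.
    deadline = asyncio.get_running_loop().time() + stop_after
    async with aiofiles.open(path) as f:
        async for line in f:  # drain existing content first
            print(line, end="")
        while asyncio.get_running_loop().time() < deadline:
            line = await f.readline()  # returns "" at EOF without blocking
            if line:
                print(line, end="")
            else:
                await asyncio.sleep(0.3)


asyncio.run(tail("job.log"))  # "job.log" is a placeholder path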