Skip to content

Commit 9701c73

Browse files
committed
Add asyncio lock to avoid race condition and refactor to encapsulate the function inside the server state
1 parent 9943bb6 commit 9701c73

File tree

1 file changed

+133
-66
lines changed

1 file changed

+133
-66
lines changed

lmms_eval/entrypoints/http_server.py

Lines changed: 133 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,129 @@ class HealthResponse(BaseModel):
101101

102102

103103
class ServerState:
104-
"""Global server state container."""
104+
"""Global server state container with thread-safe job management."""
105105

106106
def __init__(self):
107107
self.job_queue: asyncio.Queue = None
108-
self.jobs: Dict[str, JobInfo] = {}
108+
self._jobs: Dict[str, JobInfo] = {}
109+
self._jobs_lock: asyncio.Lock = None
109110
self.worker_task: asyncio.Task = None
110111
self.current_job_id: Optional[str] = None
111112

112113
def reset(self):
113114
self.job_queue = asyncio.Queue()
114-
self.jobs = {}
115+
self._jobs = {}
116+
self._jobs_lock = asyncio.Lock()
115117
self.worker_task = None
116118
self.current_job_id = None
117119

120+
# -------------------------------------------------------------------------
121+
# Thread-safe job operations
122+
# -------------------------------------------------------------------------
123+
124+
async def get_job(self, job_id: str) -> Optional[JobInfo]:
125+
"""Get a job by ID (thread-safe)."""
126+
async with self._jobs_lock:
127+
return self._jobs.get(job_id)
128+
129+
async def get_job_with_position(self, job_id: str) -> Optional[JobInfo]:
130+
"""Get a job by ID, updating queue position if queued (thread-safe)."""
131+
async with self._jobs_lock:
132+
job = self._jobs.get(job_id)
133+
if job is None:
134+
return None
135+
136+
if job.status == JobStatus.QUEUED:
137+
position = sum(1 for j in self._jobs.values() if j.status == JobStatus.QUEUED and j.created_at < job.created_at)
138+
job.position_in_queue = position
139+
140+
return job
141+
142+
async def add_job(self, request: "EvaluateRequest") -> tuple[str, int]:
143+
"""Create and queue a new job. Returns (job_id, position)."""
144+
job_id = str(uuid.uuid4())
145+
146+
async with self._jobs_lock:
147+
position = self.job_queue.qsize()
148+
job = JobInfo(
149+
job_id=job_id,
150+
status=JobStatus.QUEUED,
151+
created_at=datetime.now().isoformat(),
152+
request=request,
153+
position_in_queue=position,
154+
)
155+
self._jobs[job_id] = job
156+
await self.job_queue.put(job_id)
157+
158+
return job_id, position
159+
160+
async def start_job(self, job_id: str) -> Optional[dict]:
161+
"""
162+
Mark a job as running and return its config.
163+
Returns None if job doesn't exist or is cancelled.
164+
"""
165+
async with self._jobs_lock:
166+
job = self._jobs.get(job_id)
167+
if job is None or job.status == JobStatus.CANCELLED:
168+
return None
169+
170+
self.current_job_id = job_id
171+
job.status = JobStatus.RUNNING
172+
job.started_at = datetime.now().isoformat()
173+
return job.request.model_dump()
174+
175+
async def complete_job(self, job_id: str, result: Dict[str, Any]):
176+
"""Mark a job as completed with results."""
177+
async with self._jobs_lock:
178+
job = self._jobs.get(job_id)
179+
if job:
180+
job.status = JobStatus.COMPLETED
181+
job.completed_at = datetime.now().isoformat()
182+
job.result = result
183+
184+
async def fail_job(self, job_id: str, error: str):
185+
"""Mark a job as failed with error message."""
186+
async with self._jobs_lock:
187+
job = self._jobs.get(job_id)
188+
if job:
189+
job.status = JobStatus.FAILED
190+
job.completed_at = datetime.now().isoformat()
191+
job.error = error
192+
193+
async def cancel_job(self, job_id: str) -> tuple[bool, str]:
194+
"""
195+
Cancel a queued job.
196+
Returns (success, message) tuple.
197+
"""
198+
async with self._jobs_lock:
199+
job = self._jobs.get(job_id)
200+
if job is None:
201+
return False, f"Job {job_id} not found"
202+
203+
if job.status == JobStatus.RUNNING:
204+
return False, "Cannot cancel a running job"
205+
206+
if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
207+
return False, "Job already finished or cancelled"
208+
209+
job.status = JobStatus.CANCELLED
210+
job.completed_at = datetime.now().isoformat()
211+
return True, f"Job {job_id} cancelled"
212+
213+
async def get_queue_stats(self) -> dict:
214+
"""Get queue statistics (thread-safe)."""
215+
async with self._jobs_lock:
216+
queued = [jid for jid, j in self._jobs.items() if j.status == JobStatus.QUEUED]
217+
completed = sum(1 for j in self._jobs.values() if j.status == JobStatus.COMPLETED)
218+
failed = sum(1 for j in self._jobs.values() if j.status == JobStatus.FAILED)
219+
220+
return {
221+
"queued": queued,
222+
"completed": completed,
223+
"failed": failed,
224+
"running_job": self.current_job_id,
225+
}
226+
118227

119228
state = ServerState()
120229

@@ -252,31 +361,20 @@ async def job_worker():
252361
while True:
253362
try:
254363
job_id = await state.job_queue.get()
255-
job = state.jobs.get(job_id)
256364

257-
if job is None or job.status == JobStatus.CANCELLED:
365+
# Start job and get config (returns None if cancelled/missing)
366+
config = await state.start_job(job_id)
367+
if config is None:
258368
state.job_queue.task_done()
259369
continue
260370

261-
# Update job status
262-
state.current_job_id = job_id
263-
job.status = JobStatus.RUNNING
264-
job.started_at = datetime.now().isoformat()
265-
266371
try:
267-
# Run evaluation
268-
config = job.request.model_dump()
372+
# Run evaluation (outside lock to allow other operations)
269373
result = await run_evaluation_subprocess(config)
270-
271-
# Update job with results
272-
job.status = JobStatus.COMPLETED
273-
job.completed_at = datetime.now().isoformat()
274-
job.result = result
374+
await state.complete_job(job_id, result)
275375

276376
except Exception as e:
277-
job.status = JobStatus.FAILED
278-
job.completed_at = datetime.now().isoformat()
279-
job.error = str(e)
377+
await state.fail_job(job_id, str(e))
280378

281379
finally:
282380
state.current_job_id = None
@@ -346,21 +444,7 @@ async def submit_evaluation(request: EvaluateRequest):
346444
The job will be queued and processed sequentially.
347445
Use GET /jobs/{job_id} to check status and retrieve results.
348446
"""
349-
job_id = str(uuid.uuid4())
350-
position = state.job_queue.qsize()
351-
352-
# Create job info
353-
job = JobInfo(
354-
job_id=job_id,
355-
status=JobStatus.QUEUED,
356-
created_at=datetime.now().isoformat(),
357-
request=request,
358-
position_in_queue=position,
359-
)
360-
state.jobs[job_id] = job
361-
362-
# Add to queue
363-
await state.job_queue.put(job_id)
447+
job_id, position = await state.add_job(request)
364448

365449
return JobSubmitResponse(
366450
job_id=job_id,
@@ -373,35 +457,23 @@ async def submit_evaluation(request: EvaluateRequest):
373457
@app.get("/jobs/{job_id}", response_model=JobInfo)
374458
async def get_job_status(job_id: str):
375459
"""Get the status and results of a job."""
376-
job = state.jobs.get(job_id)
460+
job = await state.get_job_with_position(job_id)
377461
if job is None:
378462
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
379-
380-
# Update position in queue if still queued
381-
if job.status == JobStatus.QUEUED:
382-
# Count jobs ahead in queue
383-
position = 0
384-
for jid, j in state.jobs.items():
385-
if j.status == JobStatus.QUEUED and j.created_at < job.created_at:
386-
position += 1
387-
job.position_in_queue = position
388-
389463
return job
390464

391465

392466
@app.get("/queue", response_model=QueueStatusResponse)
393467
async def get_queue_status():
394468
"""Get the current queue status."""
395-
queued = [jid for jid, j in state.jobs.items() if j.status == JobStatus.QUEUED]
396-
completed = sum(1 for j in state.jobs.values() if j.status == JobStatus.COMPLETED)
397-
failed = sum(1 for j in state.jobs.values() if j.status == JobStatus.FAILED)
469+
stats = await state.get_queue_stats()
398470

399471
return QueueStatusResponse(
400-
queue_size=len(queued),
401-
running_job=state.current_job_id,
402-
queued_jobs=queued,
403-
completed_jobs=completed,
404-
failed_jobs=failed,
472+
queue_size=len(stats["queued"]),
473+
running_job=stats["running_job"],
474+
queued_jobs=stats["queued"],
475+
completed_jobs=stats["completed"],
476+
failed_jobs=stats["failed"],
405477
)
406478

407479

@@ -412,20 +484,15 @@ async def cancel_job(job_id: str):
412484
413485
Note: Running jobs cannot be cancelled.
414486
"""
415-
job = state.jobs.get(job_id)
416-
if job is None:
417-
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
418-
419-
if job.status == JobStatus.RUNNING:
420-
raise HTTPException(status_code=400, detail="Cannot cancel a running job")
421-
422-
if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
423-
raise HTTPException(status_code=400, detail="Job already finished or cancelled")
487+
success, message = await state.cancel_job(job_id)
424488

425-
job.status = JobStatus.CANCELLED
426-
job.completed_at = datetime.now().isoformat()
489+
if not success:
490+
# Determine appropriate status code based on message
491+
if "not found" in message:
492+
raise HTTPException(status_code=404, detail=message)
493+
raise HTTPException(status_code=400, detail=message)
427494

428-
return {"message": f"Job {job_id} cancelled"}
495+
return {"message": message}
429496

430497

431498
@app.get("/tasks")

0 commit comments

Comments
 (0)