Skip to content

Commit 9063788

Browse files
committed
schd: Requeue task if training server is unavailable
Restructure the logic around server response handling to allow for more fine-grained control. If the training server is reported unavailable due to a training job already running, the task is requeued instead of being dropped. A shortcoming of this approach, given that RabbitMQ attempts to place requeued messages in their original positions or near the head of the queue, is that long-running tasks (e.g. training/evaluation) can block the execution of others that could be processed in parallel on the same servers (CMS only allows running a single training task at a time to work around resource limitations, but places no restrictions on the execution of short-running tasks). This could be addressed by introducing a separate queue for long-running tasks, or by increasing the priority of short-running tasks in the queue. In any case, we expect to have a more robust solution once we move to a more sophisticated, resource- and task-aware scheduler. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 3ab6821 commit 9063788

File tree

1 file changed

+87
-66
lines changed

1 file changed

+87
-66
lines changed

cogstack_model_gateway/scheduler/scheduler.py

Lines changed: 87 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import time
33

4-
import requests
4+
from requests import Response, request
55

66
from cogstack_model_gateway.common.object_store import ObjectStoreManager
77
from cogstack_model_gateway.common.queue import QueueManager
@@ -30,7 +30,6 @@ def run(self):
3030
self.queue_manager.consume(self.process_task)
3131

3232
def process_task(self, task: dict, ack: callable, nack: callable) -> None:
33-
# FIXME: Handle ACK and NACK appropriately
3433
task_uuid = task["uuid"]
3534
log.info(f"Processing task '{task_uuid}'")
3635

@@ -41,85 +40,41 @@ def process_task(self, task: dict, ack: callable, nack: callable) -> None:
4140
task_obj = self.handle_server_response(task_uuid, res, err_msg, ack, nack)
4241
self.send_notification(task_obj)
4342

44-
def route_task(self, task: dict) -> tuple[requests.Response, str]:
43+
def route_task(self, task: dict) -> tuple[Response, str]:
4544
log.info(f"Routing task '{task['uuid']}' to model server at {task['url']}")
46-
request = self._prepare_request(task)
47-
45+
req = self._prepare_request(task)
46+
response = None
4847
try:
49-
log.debug(f"Request: {request}")
50-
response = requests.request(
51-
method=request["method"],
52-
url=request["url"],
53-
headers=request["headers"],
54-
params=request["params"],
55-
data=request["data"],
56-
files=request["files"],
48+
log.debug(f"Request: {req}")
49+
response = request(
50+
method=req["method"],
51+
url=req["url"],
52+
headers=req["headers"],
53+
params=req["params"],
54+
data=req["data"],
55+
files=req["files"],
5756
)
58-
except Exception as e:
59-
err_msg = f"Failed to forward task '{task['uuid']}': {e}"
60-
log.error(err_msg)
61-
return None, err_msg
62-
63-
try:
6457
log.debug(f"Response: {response.text}")
6558
response.raise_for_status()
6659
log.info(f"Task '{task['uuid']}' forwarded successfully to {task['url']}")
6760
return response, None
68-
except requests.HTTPError:
69-
err_msg = f"Failed to process task '{task['uuid']}']: {response.json()}"
61+
except Exception as e:
62+
err_msg = f"Failed to forward task '{task['uuid']}': {e}"
7063
log.error(err_msg)
71-
return None, err_msg
64+
return response, err_msg
7265

7366
def handle_server_response(
7467
self,
7568
task_uuid: str,
76-
response: requests.Response,
69+
response: Response,
7770
err_msg: str,
7871
ack: callable,
7972
nack: callable,
8073
) -> Task:
81-
if response is None:
82-
# FIXME: Perhaps set task to a different status?
83-
# Pending and requeued? Or failed and done with?
84-
# Should we reprocess failed tasks? How can we tell transient failures?
85-
ack()
86-
return self.task_manager.update_task(
87-
task_uuid, status=Status.FAILED, error_message=err_msg or "Failed to process task"
88-
)
89-
ack()
90-
91-
if response.status_code == 202:
92-
log.info(f"Task '{task_uuid}' accepted for processing, waiting for results")
93-
tracking_id = response.json().get("run_id") if response.json() else None
94-
self.task_manager.update_task(
95-
task_uuid,
96-
status=Status.RUNNING,
97-
expected_status=Status.SCHEDULED,
98-
tracking_id=tracking_id,
99-
)
100-
101-
results = self.poll_task_status(task_uuid, tracking_id)
102-
if results["status"] == Status.FAILED:
103-
log.error(f"Task '{task_uuid}' failed: {results['error']}")
104-
return self.task_manager.update_task(
105-
task_uuid, status=Status.FAILED, error_message=str(results["error"])
106-
)
107-
else:
108-
log.info(f"Task '{task_uuid}' completed, writing results to object store")
109-
object_key = self.results_object_store_manager.upload_object(
110-
results["url"].encode(), "results.url", prefix=task_uuid
111-
)
112-
return self.task_manager.update_task(
113-
task_uuid, status=Status.SUCCEEDED, result=object_key
114-
)
74+
if response is None or response.status_code >= 400:
75+
return self._handle_task_failure(task_uuid, response, err_msg, nack)
11576
else:
116-
log.info(f"Task '{task_uuid}' completed, writing results to object store")
117-
object_key = self.results_object_store_manager.upload_object(
118-
response.content, "results.json", prefix=task_uuid
119-
)
120-
return self.task_manager.update_task(
121-
task_uuid, status=Status.SUCCEEDED, result=object_key
122-
)
77+
return self._handle_task_success(task_uuid, response, ack)
12378

12479
def poll_task_status(self, task_uuid: str, tracking_id: str = None) -> dict:
12580
while True:
@@ -137,8 +92,9 @@ def poll_task_status(self, task_uuid: str, tracking_id: str = None) -> dict:
13792
time.sleep(5)
13893

13994
def send_notification(self, task: Task):
140-
# FIXME: notify user
141-
log.info(f"Task '{task.uuid}' {task.status.value}: {task.result or task.error_message}")
95+
# FIXME: notify user if task is completed
96+
if task.status.is_final():
97+
log.info(f"Task '{task.uuid}' {task.status.value}: {task.result or task.error_message}")
14298

14399
def _get_payload_from_refs(self, refs: list) -> str:
144100
if len(refs) > 1:
@@ -180,3 +136,68 @@ def _prepare_request(self, task: dict) -> dict:
180136
"files": files,
181137
"headers": headers,
182138
}
139+
140+
def _handle_task_failure(
    self, task_uuid: str, response: Response, err_msg: str, nack: callable
) -> Task:
    """Reject a task that the model server did not accept and record the outcome.

    Three cases are distinguished:

    * ``response`` is ``None`` (the request never produced a response): the message is
      nack'ed without requeueing and the task is marked ``FAILED`` with ``err_msg``.
    * The server answered 503 and its JSON body carries both ``experiment_id`` and
      ``run_id`` (a training run is already in progress): the message is nack'ed with the
      callback's default arguments (expected to requeue it, per the log message below) and
      the task is set back to ``PENDING``.
    * Any other error response: the message is nack'ed without requeueing and the task is
      marked ``FAILED`` with the response text.

    Args:
        task_uuid: Identifier of the task being processed.
        response: Error response from the model server, or ``None`` if none was received.
        err_msg: Error description produced while forwarding the task (may be empty).
        nack: Callback used to negatively acknowledge the queue message.

    Returns:
        The task object returned by ``task_manager.update_task``.
    """
    # FIXME: Add fine-grained error handling for different status codes

    # Compare against None explicitly: a `requests.Response` is falsy for every 4xx/5xx
    # status (its truth value mirrors `Response.ok`), so `not response` would send all
    # error responses down this branch and the 503 requeue path could never run.
    if response is None:
        nack(requeue=False)
        return self.task_manager.update_task(
            task_uuid, status=Status.FAILED, error_message=err_msg or "Failed to process task"
        )

    if response.status_code == 503:
        # Parse the body once; a 503 from a proxy or a crashed server may not be JSON (or
        # may not be a JSON object), in which case it is treated as an unexpected error.
        try:
            body = response.json()
        except ValueError:
            body = None
        details = body if isinstance(body, dict) else {}
        experiment_id = details.get("experiment_id")
        run_id = details.get("run_id")

        if experiment_id and run_id:
            warn_msg = (
                f"Task '{task_uuid}' wasn't accepted for processing: a training run is already in"
                f" progress (experiment_id={experiment_id}, run_id={run_id}). Requeuing task..."
            )
            log.warning(warn_msg)
            nack()
            return self.task_manager.update_task(
                task_uuid, status=Status.PENDING, error_message=warn_msg
            )

    log.error(f"Task '{task_uuid}' failed with unexpected error: {response.text}")
    nack(requeue=False)
    return self.task_manager.update_task(
        task_uuid, status=Status.FAILED, error_message=response.text
    )
169+
170+
def _handle_task_success(self, task_uuid: str, response: Response, ack: callable) -> Task:
    """Acknowledge a task the model server accepted and persist its results.

    The queue message is acknowledged first. For any status other than 202 the response
    content itself is uploaded as ``results.json``. For 202 the task is moved from
    ``SCHEDULED`` to ``RUNNING`` and ``poll_task_status`` is called to obtain the outcome;
    on success the returned results URL is uploaded as ``results.url``.

    Args:
        task_uuid: Identifier of the task being processed.
        response: Non-error response received from the model server.
        ack: Callback used to acknowledge the queue message.

    Returns:
        The task object returned by ``task_manager.update_task``, in either the
        ``SUCCEEDED`` or the ``FAILED`` state.
    """
    ack()

    if response.status_code != 202:
        log.info(f"Task '{task_uuid}' completed, writing results to object store")
        object_key = self.results_object_store_manager.upload_object(
            response.content, "results.json", prefix=task_uuid
        )
        return self.task_manager.update_task(
            task_uuid, status=Status.SUCCEEDED, result=object_key
        )

    log.info(f"Task '{task_uuid}' accepted for processing, waiting for results")
    # Parse the body once. The message has already been acknowledged, so a malformed or
    # non-object body must not raise here; fall back to polling without a tracking ID.
    try:
        body = response.json()
    except ValueError:
        body = None
    tracking_id = body.get("run_id") if isinstance(body, dict) else None

    self.task_manager.update_task(
        task_uuid,
        status=Status.RUNNING,
        expected_status=Status.SCHEDULED,
        tracking_id=tracking_id,
    )

    results = self.poll_task_status(task_uuid, tracking_id)
    if results["status"] == Status.FAILED:
        log.error(f"Task '{task_uuid}' failed: {results['error']}")
        return self.task_manager.update_task(
            task_uuid, status=Status.FAILED, error_message=str(results["error"])
        )

    log.info(f"Task '{task_uuid}' completed, writing results to object store")
    object_key = self.results_object_store_manager.upload_object(
        results["url"].encode(), "results.url", prefix=task_uuid
    )
    return self.task_manager.update_task(
        task_uuid, status=Status.SUCCEEDED, result=object_key
    )

0 commit comments

Comments
 (0)