Skip to content

Commit 1e92112

Browse files
Pause and cancel for task master (#255)
* cancel markdown and macro * return + delete by task id * cancel macro * cancel tmp doc * remove print * model * attribute cancel * attr cancel * simple payload cancel check * improve cancel * model + camel case * task queue pause info * model * directly set to failed etl * model merge * model update --------- Co-authored-by: LennartSchmidtKern <[email protected]>
1 parent 5b9feb0 commit 1e92112

File tree

8 files changed

+162
-42
lines changed

8 files changed

+162
-42
lines changed

controller/attribute/manager.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,11 @@ def __calculate_user_attribute_all_records(
309309
util.add_log_to_attribute_logs(project_id, attribute_id, "Finished writing.")
310310

311311
attribute_item = attribute.get(project_id, attribute_id)
312-
if attribute_item.data_type == DataTypes.TEXT.value:
312+
if (
313+
attribute_item
314+
and attribute_item.data_type == DataTypes.TEXT.value
315+
and not attribute_item.state == AttributeState.FAILED.value
316+
):
313317
util.add_log_to_attribute_logs(
314318
project_id, attribute_id, "Triggering tokenization."
315319
)
@@ -346,6 +350,15 @@ def __calculate_user_attribute_all_records(
346350
)
347351
request_reupload_docbins(project_id)
348352

353+
attribute_item = attribute.get(project_id, attribute_id)
354+
if attribute_item.state == AttributeState.FAILED.value:
355+
__notify_attribute_calculation_failed(
356+
project_id=project_id,
357+
attribute_id=attribute_id,
358+
log="Writing to the database failed.",
359+
)
360+
general.remove_and_refresh_session(session_token)
361+
return
349362
util.set_progress(project_id, attribute_item, 1.0)
350363
attribute.update(
351364
project_id=project_id,

controller/attribute/util.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from submodules.s3 import controller as s3
2121
from util import daemon, notification
2222
from controller.knowledge_base import util as knowledge_base
23+
from submodules.model import enums
2324

2425
client = docker.from_env()
2526
image = os.getenv("AC_EXEC_ENV_IMAGE")
@@ -189,7 +190,11 @@ def read_container_logs_thread(
189190
c += 1
190191
if c > 100:
191192
ctx_token = general.remove_and_refresh_session(ctx_token, True)
192-
attribute_item = attribute.get(project_id, attribute_id)
193+
attribute_item = attribute.get(project_id, attribute_id)
194+
if not attribute_item:
195+
break
196+
if attribute_item.state == enums.AttributeState.FAILED.value:
197+
break
193198
if not name in __containers_running:
194199
break
195200
try:

controller/monitor/manager.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Any, List
1+
from typing import Any, List, Dict
22
from submodules.model.business_objects import monitor as task_monitor
33
from controller.auth import kratos
44
from submodules.model.util import sql_alchemy_to_dict
5+
from submodules.s3 import controller as s3
56

67

78
def monitor_all_tasks(page: int, limit: int) -> List[Any]:
@@ -33,36 +34,85 @@ def cancel_upload_task(project_id: str = None, upload_task_id: str = None) -> No
3334
task_monitor.set_upload_task_to_failed(project_id, upload_task_id, with_commit=True)
3435

3536

36-
def cancel_weak_supervision(project_id: str = None, payload_id: str = None) -> None:
37-
task_monitor.set_weak_supervision_to_failed(
38-
project_id, payload_id, with_commit=True
39-
)
37+
def cancel_weak_supervision(
38+
task_info: Dict[str, Any],
39+
) -> None:
40+
project_id = task_info.get("projectId")
41+
payload_id = task_info.get("payloadId")
42+
if project_id and payload_id:
43+
task_monitor.set_weak_supervision_to_failed(
44+
project_id, payload_id, with_commit=True
45+
)
4046

4147

4248
def cancel_attribute_calculation(
43-
project_id: str = None, attribute_id: str = None
49+
task_info: Dict[str, Any],
4450
) -> None:
45-
task_monitor.set_attribute_calculation_to_failed(
46-
project_id, attribute_id, with_commit=True
47-
)
4851

52+
project_id = task_info.get("projectId")
53+
attribute_id = task_info.get("attributeId")
54+
if project_id and attribute_id:
55+
task_monitor.set_attribute_calculation_to_failed(
56+
project_id, attribute_id, with_commit=True
57+
)
4958

50-
def cancel_embedding(project_id: str = None, embedding_id: str = None) -> None:
51-
task_monitor.set_embedding_to_failed(project_id, embedding_id, with_commit=True)
59+
60+
def cancel_embedding(
61+
task_info: Dict[str, Any],
62+
) -> None:
63+
project_id = task_info.get("projectId")
64+
embedding_id = task_info.get("embeddingId")
65+
if project_id and embedding_id:
66+
task_monitor.set_embedding_to_failed(project_id, embedding_id, with_commit=True)
5267

5368

5469
def cancel_information_source_payload(
55-
project_id: str = None, payload_id: str = None
70+
task_info: Dict[str, Any],
5671
) -> None:
57-
task_monitor.set_information_source_payloads_to_failed(
58-
project_id, payload_id, with_commit=True
59-
)
72+
project_id = task_info.get("projectId")
73+
payload_id = task_info.get("payloadId")
74+
if project_id and payload_id:
75+
task_monitor.set_information_source_payloads_to_failed(
76+
project_id, payload_id, with_commit=True
77+
)
6078

6179

6280
def cancel_record_tokenization_task(
63-
project_id: str = None,
64-
tokenization_task_id: str = None,
81+
task_info: Dict[str, Any],
82+
) -> None:
83+
project_id = task_info.get("projectId")
84+
tokenization_task_id = task_info.get("recordTokenizationTaskId")
85+
if project_id and tokenization_task_id:
86+
task_monitor.set_record_tokenization_task_to_failed(
87+
project_id, tokenization_task_id, with_commit=True
88+
)
89+
90+
91+
def cancel_macro_execution_task(
92+
task_info: Dict[str, Any],
6593
) -> None:
66-
task_monitor.set_record_tokenization_task_to_failed(
67-
project_id, tokenization_task_id, with_commit=True
94+
95+
macro_execution_id = task_info.get("executionId")
96+
macro_execution_group_id = task_info.get("groupExecutionId")
97+
98+
task_monitor.set_macro_execution_task_to_failed(
99+
macro_execution_id, macro_execution_group_id, with_commit=True
68100
)
101+
102+
103+
def cancel_markdown_file_task(
104+
task_info: Dict[str, Any],
105+
) -> None:
106+
markdown_file_id = task_info.get("fileId")
107+
org_id = task_info.get("orgId")
108+
task_monitor.set_markdown_file_task_to_failed(
109+
markdown_file_id, org_id, with_commit=True
110+
)
111+
112+
113+
def cancel_tmp_doc_retrieval_task(
114+
task_info: Dict[str, Any],
115+
) -> None:
116+
bucket = task_info.get("bucket")
117+
minio_path = task_info.get("minioPath")
118+
s3.delete_object(bucket, minio_path)

controller/payload/payload_scheduler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,21 @@ def update_records(
533533
) -> bool:
534534
org_id = organization.get_id_by_project_id(project_id)
535535
tmp_log_store = information_source_payload.logs
536+
537+
if information_source_payload.state == enums.PayloadState.FAILED.value:
538+
berlin_now = datetime.datetime.now(__tz)
539+
tmp_log_store.append(
540+
" ".join(
541+
[
542+
berlin_now.strftime("%Y-%m-%dT%H:%M:%S"),
543+
"Information source task cancelled.",
544+
]
545+
)
546+
)
547+
information_source_payload.logs = tmp_log_store
548+
flag_modified(information_source_payload, "logs")
549+
general.commit()
550+
return True
536551
try:
537552
output_data = json.loads(
538553
s3.get_object(

controller/task_master/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,7 @@ def pause_task_queue(task_queue_pause: bool) -> requests.Response:
3535
return requests.post(
3636
f"{TASK_MASTER_URL}/task/queue/pause?task_queue_pause={task_queue_pause}"
3737
)
38+
39+
40+
def get_task_queue_pause() -> requests.Response:
41+
return requests.get(f"{TASK_MASTER_URL}/task/queue/pause")

fast_api/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,8 @@ class DeleteUserBody(BaseModel):
433433

434434

435435
class CancelTaskBody(BaseModel):
436-
project_id: StrictStr
437436
task_id: StrictStr
437+
task_info: Dict[StrictStr, Any]
438438
task_type: StrictStr
439439

440440

fast_api/routes/misc.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
CustomerButtonType,
2727
CustomerButtonLocation,
2828
)
29+
from submodules.model.business_objects import task_queue as task_queue_bo
2930

3031
router = APIRouter()
3132

@@ -114,24 +115,37 @@ def cancel_task(
114115

115116
auth.check_admin_access(request.state.info)
116117
task_type = body.task_type
117-
project_id = body.project_id
118+
task_info = body.task_info
118119
task_id = body.task_id
119120

120-
if task_type == enums.TaskType.ATTRIBUTE_CALCULATION.value:
121-
controller_manager.cancel_attribute_calculation(project_id, task_id)
122-
elif task_type == enums.TaskType.EMBEDDING.value:
123-
controller_manager.cancel_embedding(project_id, task_id)
124-
elif task_type == enums.TaskType.INFORMATION_SOURCE.value:
125-
controller_manager.cancel_information_source_payload(project_id, task_id)
126-
elif task_type == enums.TaskType.TOKENIZATION.value:
127-
controller_manager.cancel_record_tokenization_task(project_id, task_id)
128-
elif task_type == enums.TaskType.UPLOAD_TASK.value:
129-
controller_manager.cancel_upload_task(project_id, task_id)
130-
elif task_type == enums.TaskType.WEAK_SUPERVISION.value:
131-
controller_manager.cancel_weak_supervision(project_id, task_id)
132-
else:
133-
raise ValueError(f"{task_type} is no valid task type")
134-
121+
task_entity = task_queue_bo.get(task_id)
122+
if not task_entity:
123+
return pack_json_result({"data": {"cancelTask": {"ok": False}}})
124+
if task_entity and (
125+
task_entity.is_active or task_type == enums.TaskType.PARSE_MARKDOWN_FILE.value
126+
):
127+
if task_type == enums.TaskType.ATTRIBUTE_CALCULATION.value:
128+
controller_manager.cancel_attribute_calculation(task_info)
129+
elif task_type == enums.TaskType.EMBEDDING.value:
130+
controller_manager.cancel_embedding(task_info)
131+
elif task_type == enums.TaskType.INFORMATION_SOURCE.value:
132+
controller_manager.cancel_information_source_payload(task_info)
133+
elif task_type == enums.TaskType.TOKENIZATION.value:
134+
controller_manager.cancel_record_tokenization_task(task_info)
135+
elif task_type == enums.TaskType.UPLOAD_TASK.value:
136+
controller_manager.cancel_upload_task(task_info)
137+
elif task_type == enums.TaskType.WEAK_SUPERVISION.value:
138+
controller_manager.cancel_weak_supervision(task_info)
139+
elif task_type == enums.TaskType.RUN_COGNITION_MACRO.value:
140+
controller_manager.cancel_macro_execution_task(task_info)
141+
elif task_type == enums.TaskType.PARSE_MARKDOWN_FILE.value:
142+
controller_manager.cancel_markdown_file_task(task_info)
143+
elif task_type == enums.TaskType.PARSE_COGNITION_TMP_FILE.value:
144+
controller_manager.cancel_tmp_doc_retrieval_task(task_info)
145+
else:
146+
raise ValueError(f"{task_type} is no valid task type")
147+
148+
task_queue_bo.delete_by_task_id(task_id, True)
135149
return pack_json_result({"data": {"cancelTask": {"ok": True}}})
136150

137151

@@ -142,11 +156,30 @@ def cancel_all_running_tasks(request: Request):
142156
return pack_json_result({"data": {"cancelAllRunningTasks": {"ok": True}}})
143157

144158

145-
@router.get("/pause-task-queue")
159+
@router.post("/pause-task-queue")
146160
def pause_task_queue(request: Request, task_queue_pause: bool):
147161
auth.check_admin_access(request.state.info)
148-
task_master_manager.pause_task_queue(task_queue_pause)
149-
return SILENT_SUCCESS_RESPONSE
162+
task_queue_pause_response = task_master_manager.pause_task_queue(task_queue_pause)
163+
task_queue_pause = False
164+
if task_queue_pause_response.ok:
165+
try:
166+
task_queue_pause = task_queue_pause_response.json()["task_queue_pause"]
167+
except Exception:
168+
task_queue_pause = False
169+
return pack_json_result({"taskQueuePause": task_queue_pause})
170+
171+
172+
@router.get("/pause-task-queue")
173+
def get_task_queue_pause(request: Request):
174+
auth.check_admin_access(request.state.info)
175+
task_queue_pause_response = task_master_manager.get_task_queue_pause()
176+
task_queue_pause = False
177+
if task_queue_pause_response.ok:
178+
try:
179+
task_queue_pause = task_queue_pause_response.json()["task_queue_pause"]
180+
except Exception:
181+
task_queue_pause = False
182+
return pack_json_result({"taskQueuePause": task_queue_pause})
150183

151184

152185
@router.get("/all-users-activity")

0 commit comments

Comments
 (0)