
Commit 7413437

Task Master (#241)
* clean task queue
* remove task endpoints
* alembic
* add internal endpoint dummy
* offer task executions
* information source endpoint
* return payload id
* information source, project id handling
* import wizard org ids
* project id tokenizer
* attribute calculation, embedding
* embedding
* data slice
* gate removal
* remove start gates
* task executions, wizard improvements
* improve information source
* clean
* project id
* model
* task deletion
* replace alembic + model
* PR updates
* blank lines
* clean
* model merge
* test
* model dev
1 parent 7a9c554 commit 7413437

35 files changed: +393 −1138 lines
Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+"""rework task queue
+
+Revision ID: 881102ae15f8
+Revises: 37d138040614
+Create Date: 2024-08-08 12:57:27.648167
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '881102ae15f8'
+down_revision = '37d138040614'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('task_queue',
+    sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
+    sa.Column('organization_id', postgresql.UUID(as_uuid=True), nullable=True),
+    sa.Column('task_type', sa.String(), nullable=True),
+    sa.Column('task_info', sa.JSON(), nullable=True),
+    sa.Column('priority', sa.Boolean(), nullable=True),
+    sa.Column('is_active', sa.Boolean(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('created_by', postgresql.UUID(as_uuid=True), nullable=True),
+    sa.ForeignKeyConstraint(['created_by'], ['user.id'], ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['organization_id'], ['organization.id'], ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id'),
+    schema='global'
+    )
+    op.create_index(op.f('ix_global_task_queue_created_by'), 'task_queue', ['created_by'], unique=False, schema='global')
+    op.create_index(op.f('ix_global_task_queue_organization_id'), 'task_queue', ['organization_id'], unique=False, schema='global')
+    op.drop_index('ix_task_queue_created_by', table_name='task_queue')
+    op.drop_table('task_queue')
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('task_queue',
+    sa.Column('id', postgresql.UUID(), autoincrement=False, nullable=False),
+    sa.Column('project_id', postgresql.UUID(), autoincrement=False, nullable=True),
+    sa.Column('task_type', sa.VARCHAR(), autoincrement=False, nullable=True),
+    sa.Column('task_info', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
+    sa.Column('priority', sa.BOOLEAN(), autoincrement=False, nullable=True),
+    sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True),
+    sa.Column('created_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
+    sa.Column('created_by', postgresql.UUID(), autoincrement=False, nullable=True),
+    sa.ForeignKeyConstraint(['created_by'], ['user.id'], name='task_queue_created_by_fkey', ondelete='SET NULL'),
+    sa.ForeignKeyConstraint(['project_id'], ['project.id'], name='task_queue_project_id_fkey', ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id', name='task_queue_pkey')
+    )
+    op.create_index('ix_task_queue_created_by', 'task_queue', ['created_by'], unique=False)
+    op.drop_index(op.f('ix_global_task_queue_organization_id'), table_name='task_queue', schema='global')
+    op.drop_index(op.f('ix_global_task_queue_created_by'), table_name='task_queue', schema='global')
+    op.drop_table('task_queue', schema='global')
+    # ### end Alembic commands ###
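
For reference, here is a minimal SQLAlchemy declarative model matching the table the upgrade creates: the queue moves from the public schema (keyed by project_id) to the global schema (keyed by organization_id). This is a sketch derived from the migration above, not the repository's actual model, which lives in the shared submodules.model package and may differ in base class and naming.

import uuid
from datetime import datetime

from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TaskQueue(Base):
    __tablename__ = "task_queue"
    __table_args__ = {"schema": "global"}

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Queue entries now hang off the organization, not the project.
    organization_id = Column(
        UUID(as_uuid=True),
        ForeignKey("organization.id", ondelete="CASCADE"),
        index=True,
    )
    task_type = Column(String)
    task_info = Column(JSON)
    priority = Column(Boolean)
    is_active = Column(Boolean)
    created_at = Column(DateTime, default=datetime.utcnow)
    created_by = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="SET NULL"),
        index=True,
    )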

api/project.py

Lines changed: 10 additions & 5 deletions

@@ -10,7 +10,7 @@
 from submodules.model import events
 from util import doc_ock, notification, adapter

-from controller.task_queue import manager as task_queue_manager
+from controller.task_master import manager as task_master_manager
 from submodules.model.enums import TaskType, RecordTokenizationScope

 logging.basicConfig(level=logging.DEBUG)
@@ -71,17 +71,22 @@ async def post(self, request_body) -> JSONResponse:
         adapter.check(data, project.id, user.id)

         project_manager.add_workflow_store_data_to_project(
-            user_id=user.id, project_id=project.id, file_name=name, data=data
+            user_id=user.id,
+            project_id=project.id,
+            org_id=project.organization_id,
+            file_name=name,
+            data=data,
         )

-        task_queue_manager.add_task(
-            str(project.id),
-            TaskType.TOKENIZATION,
+        task_master_manager.queue_task(
+            str(organization.id),
             str(user.id),
+            TaskType.TOKENIZATION,
             {
                 "scope": RecordTokenizationScope.PROJECT.value,
                 "include_rats": True,
                 "only_uploaded_attributes": False,
+                "project_id": str(project.id),
             },
         )
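
The same swap recurs throughout the commit: add_task was keyed by project, while queue_task is keyed by organization and user, with the project id moved into the task_info payload. A runnable stub pair, with signatures inferred from the call sites in this diff rather than from the managers' source, to make the shape change concrete:

# Stub signatures inferred from the call sites in this commit; the real
# managers live in controller/task_queue and controller/task_master.
from typing import Any, Dict


def add_task(project_id: str, task_type: str, user_id: str, task_info: Dict[str, Any]) -> None:
    # Old shape: the local queue was keyed by project.
    print(f"local queue: {task_type} for project {project_id} (user {user_id})")


def queue_task(org_id: str, user_id: str, task_type: str, task_info: Dict[str, Any]) -> None:
    # New shape: the task master is keyed by organization; the project id
    # now travels inside the payload.
    print(f"task master: {task_type} for org {org_id} (user {user_id}) info={task_info}")


queue_task(
    "org-uuid",
    "user-uuid",
    "TOKENIZATION",
    {"scope": "PROJECT", "include_rats": True, "project_id": "project-uuid"},
)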

api/transfer.py

Lines changed: 9 additions & 87 deletions

@@ -1,7 +1,7 @@
 import logging
 import traceback
 import time
-from typing import Optional, Dict
+from typing import Optional
 from starlette.endpoints import HTTPEndpoint
 from starlette.responses import PlainTextResponse, JSONResponse
 from controller.embedding.manager import recreate_embeddings
@@ -18,12 +18,11 @@
     general,
     organization,
     tokenization,
-    project as refinery_project,
+    project,
 )

 from submodules.model.cognition_objects import (
     project as cognition_project,
-    macro as macro_db_bo,
 )

 from controller.transfer import manager as transfer_manager
@@ -41,7 +40,7 @@
 from util import daemon, notification
 from controller.transfer.cognition.minio_upload import handle_cognition_file_upload

-from controller.task_queue import manager as task_queue_manager
+from controller.task_master import manager as task_master_manager
 from submodules.model.enums import TaskType, RecordTokenizationScope


@@ -243,86 +242,6 @@ def put(self, request) -> PlainTextResponse:
         return PlainTextResponse("OK")


-class CognitionParseMarkdownFile(HTTPEndpoint):
-    def post(self, request) -> PlainTextResponse:
-        refinery_project_id = request.path_params["project_id"]
-        refinery_project_item = refinery_project.get(refinery_project_id)
-        if not refinery_project_item:
-            return PlainTextResponse("Bad project id", status_code=400)
-
-        dataset_id = request.path_params["dataset_id"]
-        file_id = request.path_params["file_id"]
-
-        # via thread to ensure the endpoint returns immediately
-
-        daemon.run(
-            CognitionParseMarkdownFile.__add_parse_markdown_file_thread,
-            refinery_project_id,
-            str(refinery_project_item.created_by),
-            {
-                "org_id": str(refinery_project_item.organization_id),
-                "dataset_id": dataset_id,
-                "file_id": file_id,
-            },
-        )
-
-        return PlainTextResponse("OK")
-
-    def __add_parse_markdown_file_thread(
-        project_id: str, user_id: str, task_info: Dict[str, str]
-    ):
-
-        ctx_token = general.get_ctx_token()
-        try:
-            task_queue_manager.add_task(
-                project_id, TaskType.PARSE_MARKDOWN_FILE, user_id, task_info
-            )
-        finally:
-            general.remove_and_refresh_session(ctx_token, False)
-
-
-class CognitionStartMacroExecutionGroup(HTTPEndpoint):
-    def put(self, request) -> PlainTextResponse:
-        macro_id = request.path_params["macro_id"]
-        group_id = request.path_params["group_id"]
-
-        execution_entries = macro_db_bo.get_all_macro_executions(macro_id, group_id)
-
-        if len(execution_entries) == 0:
-            return PlainTextResponse("No executions found", status_code=400)
-        if not (cognition_prj_id := execution_entries[0].meta_info.get("project_id")):
-            return PlainTextResponse("No project id found", status_code=400)
-        cognition_prj = cognition_project.get(cognition_prj_id)
-        refinery_prj_id = str(
-            refinery_project.get_or_create_queue_project(
-                cognition_prj.organization_id, cognition_prj.created_by, True
-            ).id
-        )
-        cached = {str(e.id): str(e.created_by) for e in execution_entries}
-
-        def queue_tasks():
-            token = general.get_ctx_token()
-            try:
-                for exec_id in cached:
-                    task_queue_manager.add_task(
-                        refinery_prj_id,
-                        TaskType.RUN_COGNITION_MACRO,
-                        cached[exec_id],
-                        {
-                            "macro_id": macro_id,
-                            "execution_id": exec_id,
-                            "execution_group_id": group_id,
-                        },
-                    )
-                general.commit()
-            finally:
-                general.remove_and_refresh_session(token, False)
-
-        daemon.run(queue_tasks)
-
-        return PlainTextResponse("OK")
-
-
 class AssociationsImport(HTTPEndpoint):
     async def post(self, request) -> JSONResponse:
         project_id = request.path_params["project_id"]
@@ -404,11 +323,14 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
     )
     if task.file_type != "knowledge_base":
         only_usable_attributes = task.file_type == "records_add"
-        task_queue_manager.add_task(
-            project_id,
+        project_item = project.get(project_id)
+        org_id = project_item.organization_id
+        task_master_manager.queue_task(
+            str(org_id),
+            str(task.user_id),
             TaskType.TOKENIZATION,
-            task.user_id,
             {
+                "project_id": str(project_id),
                 "scope": RecordTokenizationScope.PROJECT.value,
                 "include_rats": True,
                 "only_uploaded_attributes": only_usable_attributes,

app.py

Lines changed: 9 additions & 12 deletions

@@ -15,8 +15,6 @@
     UploadTaskInfo,
     CognitionImport,
     CognitionPrepareProject,
-    CognitionParseMarkdownFile,
-    CognitionStartMacroExecutionGroup,
 )
 from fast_api.routes.organization import router as org_router
 from fast_api.routes.project import router as project_router
@@ -36,12 +34,12 @@
 from fast_api.routes.record import router as record_router
 from fast_api.routes.weak_supervision import router as weak_supervision_router
 from fast_api.routes.labeling_tasks import router as labeling_tasks_router
+from fast_api.routes.task_execution import router as task_execution_router
 from middleware.database_session import handle_db_session
 from middleware.starlette_tmp_middleware import DatabaseSessionHandler
 from starlette.applications import Starlette
 from starlette.routing import Route, Mount

-from controller.task_queue.task_queue import init_task_queues
 from controller.project.manager import check_in_deletion_projects
 from route_prefix import (
     PREFIX_ORGANIZATION,
@@ -62,6 +60,7 @@
     PREFIX_RECORD,
     PREFIX_WEAK_SUPERVISION,
     PREFIX_LABELING_TASKS,
+    PREFIX_TASK_EXECUTION,
 )
 from util import security, clean_up
 from middleware import log_storage
@@ -116,6 +115,10 @@
     labeling_tasks_router, prefix=PREFIX_LABELING_TASKS, tags=["labeling-tasks"]
 )

+fastapi_app_internal = FastAPI()
+fastapi_app_internal.include_router(
+    task_execution_router, prefix=PREFIX_TASK_EXECUTION, tags=["task-execution"]
+)
 routes = [
     Route("/notify/{path:path}", Notify),
     Route("/healthcheck", Healthcheck),
@@ -135,19 +138,14 @@
         "/project/{cognition_project_id:str}/cognition/continue/{task_id:str}/finalize",
         CognitionPrepareProject,
     ),
-    Route(
-        "/project/{project_id:str}/cognition/datasets/{dataset_id:str}/files/{file_id:str}/queue",
-        CognitionParseMarkdownFile,
-    ),
     Route("/project/{project_id:str}/import/task/{task_id:str}", UploadTaskInfo),
     Route("/project", ProjectCreationFromWorkflow),
-    Route(
-        "/macro/{macro_id:str}/execution-group/{group_id:str}/queue",
-        CognitionStartMacroExecutionGroup,
-    ),
     Route("/is_managed", IsManagedRest),
     Route("/is_demo", IsDemoRest),
     Mount("/api", app=fastapi_app, name="REST API"),
+    Mount(
+        "/internal/api", app=fastapi_app_internal, name="INTERNAL REST API"
+    ),  # task master requests
 ]

@@ -156,7 +154,6 @@
 middleware = [Middleware(DatabaseSessionHandler)]
 app = Starlette(routes=routes, middleware=middleware)

-init_task_queues()
 check_in_deletion_projects()
 security.check_secret_key()
 clean_up.clean_up_database()
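
The new /internal/api mount is the surface the task master calls back into: a second FastAPI app mounted inside the same Starlette application, separate from the public /api mount. A self-contained sketch of that mounting pattern; the route path and handler below are illustrative, not taken from the repository:

# Minimal demo of a FastAPI sub-app served under a Starlette Mount,
# mirroring the "/internal/api" mount above. Paths and names are
# illustrative only.
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from starlette.applications import Starlette
from starlette.routing import Mount

task_execution_router = APIRouter()


@task_execution_router.post("/task-execution/run")  # hypothetical route
def run_task() -> dict:
    # The real router would dispatch the queued task to its handler.
    return {"status": "accepted"}


internal_app = FastAPI()
internal_app.include_router(task_execution_router)

# Only the internal mount is shown; the public "/api" mount works the same way.
app = Starlette(routes=[Mount("/internal/api", app=internal_app)])

if __name__ == "__main__":
    client = TestClient(app)
    assert client.post("/internal/api/task-execution/run").status_code == 200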
