
Commit 3777401

Adds basic queue (#125)
* Adds basic queue for embeddings
* Adds queue for information sources
* Removes user object from doc ock to prevent session issues
* Adds tokenization queue
* Adds attribute calculation
* Comments
* Adds attribute tokenization to queueing logic
* PR comments
* Better error message
* Adds PR comments
* Changes condition for tokenization start task

---------

Co-authored-by: FelixKirschKern <[email protected]>
1 parent 9b4d961 commit 3777401

43 files changed: +922 −89 lines
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+"""Adds task queue table
+
+Revision ID: bb87177d46b5
+Revises: 546e5cd7feaa
+Create Date: 2023-04-26 10:03:46.597003
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "bb87177d46b5"
+down_revision = "546e5cd7feaa"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "task_queue",
+        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
+        sa.Column("task_type", sa.String(), nullable=True),
+        sa.Column("task_info", sa.JSON(), nullable=True),
+        sa.Column("priority", sa.Boolean(), nullable=True),
+        sa.Column("is_active", sa.Boolean(), nullable=True),
+        sa.Column("created_at", sa.DateTime(), nullable=True),
+        sa.Column("created_by", postgresql.UUID(as_uuid=True), nullable=True),
+        sa.ForeignKeyConstraint(
+            ["created_by"],
+            ["user.id"],
+        ),
+        sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("task_queue")
+    # ### end Alembic commands ###
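
For orientation, a row in the new task_queue table carries an id, an optional project reference, a task type, a JSON task_info payload, a priority flag, an is_active flag, and creation metadata. Mirrored as an ORM model it would look roughly like the sketch below; the class name, module placement, and defaults are assumptions for illustration only (the actual model lives in the shared submodules.model package and is not part of this diff).

# Hypothetical sketch only - mirrors the columns of the "task_queue" table created above.
import datetime
import uuid

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TaskQueue(Base):  # assumed class name, not taken from this commit
    __tablename__ = "task_queue"

    id = sa.Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    project_id = sa.Column(
        postgresql.UUID(as_uuid=True),
        sa.ForeignKey("project.id", ondelete="CASCADE"),
    )
    task_type = sa.Column(sa.String)  # e.g. TaskType.TOKENIZATION.value
    task_info = sa.Column(sa.JSON)  # task-specific payload, see the add_task calls below
    priority = sa.Column(sa.Boolean)
    is_active = sa.Column(sa.Boolean)
    created_at = sa.Column(sa.DateTime, default=datetime.datetime.utcnow)
    created_by = sa.Column(postgresql.UUID(as_uuid=True), sa.ForeignKey("user.id"))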

api/project.py

Lines changed: 14 additions & 2 deletions

@@ -12,6 +12,9 @@
 from submodules.s3.controller import bucket_exists, create_bucket
 from util import doc_ock, notification, adapter

+from controller.task_queue import manager as task_queue_manager
+from submodules.model.enums import TaskType, RecordTokenizationScope
+
 logging.basicConfig(level=logging.DEBUG)


@@ -74,13 +77,22 @@ async def post(self, request_body) -> JSONResponse:
             user_id=user.id, project_id=project.id, file_name=name, data=data
         )

-        tokenization_service.request_tokenize_project(str(project.id), str(user.id))
+        task_queue_manager.add_task(
+            str(project.id),
+            TaskType.TOKENIZATION,
+            str(user.id),
+            {
+                "scope": RecordTokenizationScope.PROJECT.value,
+                "include_rats": True,
+                "only_uploaded_attributes": False,
+            },
+        )

         notification.send_organization_update(
             project.id, f"project_created:{str(project.id)}", True
         )
         doc_ock.post_event(
-            user,
+            str(user.id),
             events.CreateProject(Name=f"{name}-{project.id}", Description=description),
         )

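
The controller/task_queue/manager.py module itself is not part of this excerpt. Judging only from the call sites in this commit, add_task takes a project id, a TaskType enum member, a user id, and a JSON-serializable task_info dict; a minimal interface sketch under that assumption (the body is intentionally a stub, not the actual implementation):

# Inferred interface sketch only - the real manager module is not shown in this diff.
from typing import Any, Dict

from submodules.model.enums import TaskType


def add_task(
    project_id: str,
    task_type: TaskType,
    user_id: str,
    task_info: Dict[str, Any],
) -> None:
    # Presumably persists a row in the new task_queue table (task_type, task_info,
    # created_by, ...) so a queue worker can pick it up later; details are assumptions.
    ...

The task_info dict varies with the scope: project-wide tokenization passes scope, include_rats and only_uploaded_attributes (as above), while attribute-scoped tokenization additionally passes an attribute_id (see controller/attribute/manager.py below).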

api/transfer.py

Lines changed: 15 additions & 6 deletions

@@ -10,7 +10,6 @@
 from starlette.responses import PlainTextResponse, JSONResponse

 from controller.transfer.labelstudio import import_preperator
-from submodules.model.business_objects.tokenization import is_doc_bin_creation_running
 from submodules.s3 import controller as s3
 from submodules.model.business_objects import (
     attribute,
@@ -31,10 +30,13 @@

 from submodules.model import enums, exceptions
 from util.notification import create_notification
-from submodules.model.enums import AttributeState, NotificationType, UploadStates
-from submodules.model.models import Embedding, UploadTask
+from submodules.model.enums import NotificationType
+from submodules.model.models import UploadTask
 from util import daemon, notification
-from controller.tokenization import tokenization_service
+
+from controller.task_queue import manager as task_queue_manager
+from submodules.model.enums import TaskType, RecordTokenizationScope
+

 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -247,8 +249,15 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
     )
     if task.file_type != "knowledge_base":
         only_usable_attributes = task.file_type == "records_add"
-        tokenization_service.request_tokenize_project(
-            project_id, str(task.user_id), True, only_usable_attributes
+        task_queue_manager.add_task(
+            project_id,
+            TaskType.TOKENIZATION,
+            task.user_id,
+            {
+                "scope": RecordTokenizationScope.PROJECT.value,
+                "include_rats": True,
+                "only_uploaded_attributes": only_usable_attributes,
+            },
         )


app.py

Lines changed: 4 additions & 0 deletions

@@ -18,6 +18,8 @@
 from starlette.routing import Route

 from graphql_api import schema
+from controller.task_queue.task_queue import init_task_queue
+

 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -48,3 +50,5 @@
 middleware = [Middleware(DatabaseSessionHandler)]

 app = Starlette(routes=routes, middleware=middleware)
+
+init_task_queue()
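
init_task_queue() is invoked once at import time, right after the Starlette app is constructed. Its implementation is not shown in this commit; a plausible reading, given the new task_queue table, is that it resumes tasks that were persisted but not finished before a restart. Purely as an illustration (everything below except the function name is an assumption):

# Hypothetical sketch - not the actual controller/task_queue/task_queue.py.
def init_task_queue() -> None:
    # e.g. load rows from task_queue that are still marked is_active, ordered by
    # priority and created_at, and hand each one back to its task-type handler.
    ...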

controller/attribute/manager.py

Lines changed: 26 additions & 8 deletions

@@ -1,14 +1,19 @@
 from typing import List, Tuple
 from controller.tokenization.tokenization_service import (
-    request_tokenize_calculated_attribute,
-    request_tokenize_project,
     request_reupload_docbins,
 )
-from submodules.model.business_objects import attribute, record, tokenization, general
+from submodules.model.business_objects import (
+    attribute,
+    record,
+    tokenization,
+    general,
+)
 from submodules.model.models import Attribute
-from submodules.model.enums import AttributeState, DataTypes
+from submodules.model.enums import AttributeState, DataTypes, RecordTokenizationScope
 from util import daemon, notification

+from controller.task_queue import manager as task_queue_manager
+from submodules.model.enums import TaskType
 from . import util
 from sqlalchemy import sql

@@ -141,10 +146,15 @@ def add_running_id(
         project_id, attribute_name, for_retokenization, with_commit=True
     )
     if for_retokenization:
-        daemon.run(
-            request_tokenize_project,
+        task_queue_manager.add_task(
             project_id,
+            TaskType.TOKENIZATION,
             user_id,
+            {
+                "scope": RecordTokenizationScope.PROJECT.value,
+                "include_rats": True,
+                "only_uploaded_attributes": False,
+            },
         )


@@ -261,9 +271,17 @@ def __calculate_user_attribute_all_records(
         project_id, attribute_id, "Triggering tokenization."
     )
     try:
-        request_tokenize_calculated_attribute(
-            project_id, user_id, attribute_item.id, include_rats
+        task_queue_manager.add_task(
+            project_id,
+            TaskType.TOKENIZATION,
+            user_id,
+            {
+                "scope": RecordTokenizationScope.ATTRIBUTE.value,
+                "attribute_id": str(attribute_item.id),
+                "include_rats": include_rats,
+            },
         )
+
     except Exception:
         record.delete_user_created_attribute(
             project_id=project_id,

controller/embedding/manager.py

Lines changed: 23 additions & 0 deletions

@@ -7,6 +7,7 @@
 from . import connector
 from controller.misc import manager as misc
 from controller.model_provider import manager as model_manager
+from submodules.model.business_objects import attribute


 def get_recommended_encoders() -> List[Any]:
@@ -112,3 +113,25 @@ def __embed_one_by_one_helper(
     time.sleep(5)
     while util.has_encoder_running(project_id):
         time.sleep(5)
+
+
+def get_embedding_name(
+    project_id: str, attribute_id: str, level: str, embedding_handle: str
+) -> str:
+    if level not in [
+        enums.EmbeddingType.ON_ATTRIBUTE.value,
+        enums.EmbeddingType.ON_TOKEN.value,
+    ]:
+        raise ValueError("level must be either attribute or token")
+    embedding_type = (
+        "classification"
+        if level == enums.EmbeddingType.ON_ATTRIBUTE.value
+        else "extraction"
+    )
+
+    attribute_item = attribute.get(project_id, attribute_id)
+    if attribute_item is None:
+        raise ValueError("attribute not found")
+    attribute_name = attribute_item.name
+
+    return f"{attribute_name}-{embedding_type}-{embedding_handle}"
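
The new helper derives an embedding's name from the attribute it is built on, the embedding level (attribute vs. token), and the model handle. A small usage sketch; the UUIDs and handle below are made-up placeholder values, not values from this commit:

# Example only - placeholder arguments.
name = get_embedding_name(
    project_id="00000000-0000-0000-0000-000000000001",  # placeholder project UUID
    attribute_id="00000000-0000-0000-0000-000000000002",  # placeholder attribute UUID
    level=enums.EmbeddingType.ON_ATTRIBUTE.value,
    embedding_handle="distilbert-base-uncased",
)
# -> "<attribute name>-classification-distilbert-base-uncased"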

controller/notification/notification_data.py

Lines changed: 7 additions & 0 deletions

@@ -158,6 +158,13 @@
         "page": enums.Pages.SETTINGS.value,
         "docs": enums.DOCS.INFORMATION_SOURCES.value,
     },
+    enums.NotificationType.INFORMATION_SOURCE_S3_DOCBIN_MISSING.value: {
+        "message_template": "Tokenization docs missing in S3 storage. Docs are present once tokenization process is started (not queued).",
+        "title": "Heuristic execution",
+        "level": enums.Notification.ERROR.value,
+        "page": enums.Pages.SETTINGS.value,
+        "docs": enums.DOCS.INFORMATION_SOURCES.value,
+    },
     enums.NotificationType.WEAK_SUPERVISION_TASK_STARTED.value: {
         "message_template": "Started weak supervision.",
         "title": "Weak supervision",

controller/payload/payload_scheduler.py

Lines changed: 30 additions & 15 deletions

@@ -96,7 +96,6 @@ def create_payload(
     )


 def prepare_and_run_execution_pipeline(
-    user: User,
     payload_id: str,
     project_id: str,
     information_source_item: InformationSource,
@@ -107,7 +106,6 @@ def prepare_and_run_execution_pipeline(
         information_source_item
     )
     execution_pipeline(
-        user,
         payload_id,
         project_id,
         information_source_item,
@@ -132,11 +130,19 @@ def prepare_and_run_execution_pipeline(
 def prepare_input_data_for_payload(
     information_source_item: InformationSource,
 ) -> Tuple[str, Dict[str, Any]]:
+    org_id = organization.get_id_by_project_id(project_id)
     if (
         information_source_item.type
         == enums.InformationSourceType.LABELING_FUNCTION.value
     ):
-        # isn't collected every time but rather whenever tokenization needs to run again --> accesslink to the docbin file on s3
+        # check if docbins exist
+        if not s3.object_exists(org_id, project_id + "/" + "docbin_full"):
+            notification = create_notification(
+                enums.NotificationType.INFORMATION_SOURCE_S3_DOCBIN_MISSING,
+                user_id,
+                project_id,
+            )
+            raise ValueError(notification.message)
         return None, None

     elif (
@@ -158,7 +164,6 @@ def prepare_input_data_for_payload(
         )
         embedding_file_name = f"embedding_tensors_{embedding_id}.csv.bz2"
         embedding_item = embedding.get(project_id, embedding_id)
-        org_id = organization.get_id_by_project_id(project_id)
         if not s3.object_exists(org_id, project_id + "/" + embedding_file_name):
             notification = create_notification(
                 enums.NotificationType.INFORMATION_SOURCE_S3_EMBEDDING_MISSING,
@@ -200,7 +205,6 @@ def prepare_input_data_for_payload(
     return embedding_file_name, input_data

 def execution_pipeline(
-    user: User,
     payload_id: str,
     project_id: str,
     information_source_item: InformationSource,
@@ -309,7 +313,7 @@ def execution_pipeline(

     project_item = project.get(project_id)
     doc_ock.post_event(
-        user,
+        user_id,
         events.AddInformationSourceRun(
             ProjectName=f"{project_item.name}-{project_item.id}",
             Type=information_source_item.type,
@@ -319,18 +323,15 @@ def execution_pipeline(
         ),
     )

-    user = user_manager.get_user(user_id)
     if asynchronous:
         daemon.run(
             prepare_and_run_execution_pipeline,
-            user,
             payload.id,
             project_id,
             information_source_item,
         )
     else:
         prepare_and_run_execution_pipeline(
-            user,
             payload.id,
             project_id,
             information_source_item,
@@ -468,20 +469,33 @@ def read_container_logs_thread(
     payload_id: str,
     docker_container: Any,
 ):
+
+    ctx_token = general.get_ctx_token()
     # needs to be refetched since it is not thread safe
     information_source_payload = information_source.get_payload(project_id, payload_id)
     previous_progress = -1
     last_timestamp = None
+    c = 0
     while name in __containers_running:
         time.sleep(1)
+        c += 1
+        if c > 100:
+            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+            information_source_payload = information_source.get_payload(
+                project_id, payload_id
+            )
         if not name in __containers_running:
             break
-        log_lines = docker_container.logs(
-            stdout=True,
-            stderr=True,
-            timestamps=True,
-            since=last_timestamp,
-        )
+        try:
+            log_lines = docker_container.logs(
+                stdout=True,
+                stderr=True,
+                timestamps=True,
+                since=last_timestamp,
+            )
+        except:
+            # failsafe for containers that shut down during the read
+            break
         current_logs = [
             l for l in str(log_lines.decode("utf-8")).split("\n") if len(l.strip()) > 0
         ]
@@ -506,6 +520,7 @@ def read_container_logs_thread(
         set_payload_progress(
             project_id, information_source_payload, last_entry, factor=0.8
         )
+    general.remove_and_refresh_session(ctx_token)


 def get_inference_dir() -> str:
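
A side note on read_container_logs_thread: the thread now opens its own database session context via general.get_ctx_token(), swaps it for a fresh one after roughly 100 polling iterations, wraps the docker log read in a try/except so a container that shuts down mid-read no longer crashes the loop, and removes the session once the container finishes. A stripped-down sketch of that session-refresh pattern; the general helpers come from the diff, while the work object and the loop itself are placeholders:

# Simplified illustration of the session-refresh pattern used above; not the actual code.
import time

from submodules.model.business_objects import general


def long_running_loop(work) -> None:
    ctx_token = general.get_ctx_token()  # bind a fresh DB session to this thread
    iterations = 0
    try:
        while work.still_running():  # placeholder condition
            time.sleep(1)
            iterations += 1
            if iterations > 100:
                # drop the stale session and continue with a new one
                ctx_token = general.remove_and_refresh_session(ctx_token, True)
                iterations = 0
            work.poll()  # placeholder for the actual log reading
    finally:
        general.remove_and_refresh_session(ctx_token)  # always clean up the session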

0 commit comments
