Skip to content

Commit a0e47e4

Browse files
SimonDegrafKernJohannes HötterJWittmeyerlumburovskalina
authored
Embedding providers (#136)
* forward platform information * Adds agreement, gdpr comliant flsgs and embedding provider support * Adds created by to embeddings * Merges embedding logic for different embedding types and add support more embedding providers * Change order of embedding platforms * Added platform as part of the Encoder type * Adds modularization for recreation of embeddings * Refactors recreation of embeddings to solve dependency problems * Embedding recreation for adding new records * Fixes embedding deletion * Fixes small things * Adds notification for organization change * Adds better state management for reupload of records * Adds optimized state management for embeddings * Standardize call to recreate embeddings * Resolves PR comments * Removes print statement * Adds commit for embedding creation * Adds handling for missing tokenization * Adds logic to infer embedding information out of old projects * Adds order of project transfer so that source code can be replaced by new embedding name * Changes call logic of agreements * Resolves first few PR comments * Resolves PR comments * Adds new term text * Changed terms text slightly * Adds link and placeholder * Added link to embedding type * Update controller/transfer/project_transfer_manager.py Co-authored-by: JWittmeyer <[email protected]> * Resolves typo * Resolves typo * Submodules merge * Drone --------- Co-authored-by: Johannes Hötter <[email protected]> Co-authored-by: JWittmeyer <[email protected]> Co-authored-by: Lina <[email protected]>
1 parent 3dff874 commit a0e47e4

File tree

14 files changed

+507
-385
lines changed

14 files changed

+507
-385
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Adds agreements, gdpr flags and platform support for embeddings
2+
3+
Revision ID: 1a25c862801f
4+
Revises: 03d19eada266
5+
Create Date: 2023-06-06 13:58:16.634066
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
from sqlalchemy.dialects import postgresql
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '1a25c862801f'
14+
down_revision = '03d19eada266'
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.create_table('agreement',
22+
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
23+
sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=True),
24+
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=True),
25+
sa.Column('xfkey', postgresql.UUID(as_uuid=True), nullable=True),
26+
sa.Column('xftype', sa.String(), nullable=True),
27+
sa.Column('terms_text', sa.String(), nullable=True),
28+
sa.Column('terms_accepted', sa.Boolean(), nullable=True),
29+
sa.Column('created_at', sa.DateTime(), nullable=True),
30+
sa.ForeignKeyConstraint(['project_id'], ['project.id'], ondelete='CASCADE'),
31+
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
32+
sa.PrimaryKeyConstraint('id')
33+
)
34+
op.create_index(op.f('ix_agreement_project_id'), 'agreement', ['project_id'], unique=False)
35+
op.create_index(op.f('ix_agreement_user_id'), 'agreement', ['user_id'], unique=False)
36+
op.create_index(op.f('ix_agreement_xfkey'), 'agreement', ['xfkey'], unique=False)
37+
op.create_index(op.f('ix_agreement_xftype'), 'agreement', ['xftype'], unique=False)
38+
op.add_column('embedding', sa.Column('created_by', postgresql.UUID(as_uuid=True), nullable=True))
39+
op.add_column('embedding', sa.Column('api_token', sa.String(), nullable=True))
40+
op.add_column('embedding', sa.Column('model', sa.String(), nullable=True))
41+
op.add_column('embedding', sa.Column('platform', sa.String(), nullable=True))
42+
op.create_index(op.f('ix_embedding_created_by'), 'embedding', ['created_by'], unique=False)
43+
op.create_foreign_key(None, 'embedding', 'user', ['created_by'], ['id'], ondelete='CASCADE')
44+
op.add_column('organization', sa.Column('gdpr_compliant', sa.Boolean(), nullable=True))
45+
# ### end Alembic commands ###
46+
47+
48+
def downgrade():
49+
# ### commands auto generated by Alembic - please adjust! ###
50+
op.drop_column('organization', 'gdpr_compliant')
51+
op.drop_constraint(None, 'embedding', type_='foreignkey')
52+
op.drop_index(op.f('ix_embedding_created_by'), table_name='embedding')
53+
op.drop_column('embedding', 'platform')
54+
op.drop_column('embedding', 'model')
55+
op.drop_column('embedding', 'api_token')
56+
op.drop_column('embedding', 'created_by')
57+
op.drop_index(op.f('ix_agreement_xftype'), table_name='agreement')
58+
op.drop_index(op.f('ix_agreement_xfkey'), table_name='agreement')
59+
op.drop_index(op.f('ix_agreement_user_id'), table_name='agreement')
60+
op.drop_index(op.f('ix_agreement_project_id'), table_name='agreement')
61+
op.drop_table('agreement')
62+
# ### end Alembic commands ###

api/transfer.py

Lines changed: 15 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from controller.embedding import connector as embedding_connector
99
from starlette.endpoints import HTTPEndpoint
1010
from starlette.responses import PlainTextResponse, JSONResponse
11+
from controller.embedding.manager import recreate_embeddings
1112

1213
from controller.transfer.labelstudio import import_preperator
1314
from submodules.s3 import controller as s3
@@ -234,7 +235,12 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
234235
import_preperator.prepare_label_studio_import(project_id, task)
235236
else:
236237
transfer_manager.import_records_from_file(project_id, task)
237-
calculate_missing_attributes(project_id, task.user_id)
238+
daemon.run(
239+
__recalculate_missing_attributes_and_embeddings,
240+
project_id,
241+
str(task.user_id),
242+
)
243+
238244
elif "project" in task.file_type:
239245
transfer_manager.import_project(project_id, task)
240246
elif "knowledge_base" in task.file_type:
@@ -284,12 +290,9 @@ def file_import_error_handling(
284290
)
285291

286292

287-
def calculate_missing_attributes(project_id: str, user_id: str) -> None:
288-
daemon.run(
289-
__calculate_missing_attributes,
290-
project_id,
291-
user_id,
292-
)
293+
def __recalculate_missing_attributes_and_embeddings(project_id: str, user_id: str) -> None:
294+
__calculate_missing_attributes(project_id, user_id)
295+
recreate_embeddings(project_id)
293296

294297

295298
def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
@@ -305,6 +308,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
305308
)
306309
if len(attributes_usable) == 0:
307310
return
311+
308312
# stored as list so connection results do not affect
309313
attribute_ids = [str(att_usable.id) for att_usable in attributes_usable]
310314
for att_id in attribute_ids:
@@ -313,7 +317,6 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
313317
notification.send_organization_update(
314318
project_id=project_id, message="calculate_attribute:started:all"
315319
)
316-
317320
try:
318321
# first check project tokenization completed
319322
i = 0
@@ -323,7 +326,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
323326
i = 0
324327
ctx_token = general.remove_and_refresh_session(ctx_token, True)
325328
if tokenization.is_doc_bin_creation_running(project_id):
326-
time.sleep(5)
329+
time.sleep(2)
327330
continue
328331
else:
329332
break
@@ -350,15 +353,15 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
350353
if tokenization.is_doc_bin_creation_running_for_attribute(
351354
project_id, current_att.name
352355
):
353-
time.sleep(5)
356+
time.sleep(2)
354357
continue
355358
else:
356359
attribute_ids.pop(0)
357360
notification.send_organization_update(
358361
project_id=project_id,
359362
message=f"calculate_attribute:finished:{current_att_id}",
360363
)
361-
time.sleep(5)
364+
time.sleep(2)
362365
except Exception as e:
363366
print(
364367
f"Error while recreating attribute calculation for {project_id} when new records are uploaded : {e}"
@@ -381,80 +384,4 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
381384
message="calculate_attribute:finished:all",
382385
)
383386
general.remove_and_refresh_session(ctx_token, False)
384-
calculate_missing_embedding_tensors(project_id, user_id)
385-
386-
387-
def calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
388-
daemon.run(
389-
__calculate_missing_embedding_tensors,
390-
project_id,
391-
user_id,
392-
)
393-
394-
395-
def __calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
396-
ctx_token = general.get_ctx_token()
397-
embeddings = embedding.get_finished_embeddings_by_started_at(project_id)
398-
if len(embeddings) == 0:
399-
return
400-
401-
embedding_ids = [str(embed.id) for embed in embeddings]
402-
for embed_id in embedding_ids:
403-
embedding.update_embedding_state_waiting(project_id, embed_id)
404-
general.commit()
405-
406-
try:
407-
ctx_token = __create_embeddings(project_id, embedding_ids, user_id, ctx_token)
408-
except Exception as e:
409-
print(
410-
f"Error while recreating embeddings for {project_id} when new records are uploaded : {e}"
411-
)
412-
get_waiting_embeddings = embedding.get_waiting_embeddings(project_id)
413-
for embed in get_waiting_embeddings:
414-
embedding.update_embedding_state_failed(project_id, str(embed.id))
415-
general.commit()
416-
finally:
417-
notification.send_organization_update(
418-
project_id=project_id, message="embedding:finished:all"
419-
)
420-
general.remove_and_refresh_session(ctx_token, False)
421-
422-
423-
def __create_embeddings(
424-
project_id: str,
425-
embedding_ids: List[str],
426-
user_id: str,
427-
ctx_token: Any,
428-
) -> Any:
429-
notification.send_organization_update(
430-
project_id=project_id, message="embedding:started:all"
431-
)
432-
for embedding_id in embedding_ids:
433-
ctx_token = general.remove_and_refresh_session(ctx_token, request_new=True)
434-
embedding_item = embedding.get(project_id, embedding_id)
435-
if not embedding_item:
436-
continue
437-
438-
embedding_connector.request_deleting_embedding(project_id, embedding_id)
439-
440-
attribute_id = str(embedding_item.attribute_id)
441-
attribute_name = attribute.get(project_id, attribute_id).name
442-
if embedding_item.type == enums.EmbeddingType.ON_ATTRIBUTE.value:
443-
prefix = f"{attribute_name}-classification-"
444-
config_string = embedding_item.name[len(prefix) :]
445-
embedding_connector.request_creating_attribute_level_embedding(
446-
project_id, attribute_id, user_id, config_string
447-
)
448-
else:
449-
prefix = f"{attribute_name}-extraction-"
450-
config_string = embedding_item.name[len(prefix) :]
451-
embedding_connector.request_creating_token_level_embedding(
452-
project_id, attribute_id, user_id, config_string
453-
)
454-
time.sleep(5)
455-
while embedding_util.has_encoder_running(project_id):
456-
if embedding_item.state == enums.EmbeddingState.WAITING.value:
457-
break
458-
time.sleep(1)
459-
return ctx_token
460-
387+

controller/embedding/connector.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,13 @@ def request_listing_recommended_encoders() -> Any:
1010
url = f"{BASE_URI}/classification/recommend/TEXT" # TODO does here have to be a data type?
1111
return service_requests.get_call_or_raise(url)
1212

13-
14-
def request_creating_attribute_level_embedding(
15-
project_id: str, attribute_id: str, user_id: str, config_string: str
16-
) -> Any:
17-
url = f"{BASE_URI}/classification/encode"
18-
data = {
19-
"project_id": str(project_id),
20-
"attribute_id": str(attribute_id),
21-
"user_id": str(user_id),
22-
"config_string": config_string,
23-
}
24-
return service_requests.post_call_or_raise(url, data)
25-
26-
27-
def request_creating_token_level_embedding(
28-
project_id: str, attribute_id: str, user_id: str, config_string: str
13+
def request_embedding(
14+
project_id: str, embedding_id: str
2915
) -> Any:
30-
url = f"{BASE_URI}/extraction/encode"
16+
url = f"{BASE_URI}/embed"
3117
data = {
3218
"project_id": str(project_id),
33-
"attribute_id": str(attribute_id),
34-
"user_id": str(user_id),
35-
"config_string": config_string,
19+
"embedding_id": str(embedding_id),
3620
}
3721
return service_requests.post_call_or_raise(url, data)
3822

0 commit comments

Comments
 (0)