
Commit 5ba1e12

Embedding list prep (#148)

* Embedding lists

1 parent 7118ae7 commit 5ba1e12

File tree

8 files changed, +70 -21 lines changed
Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+"""adds embedding tensor subkey
+
+Revision ID: 53c561be097d
+Revises: 0714589d508e
+Create Date: 2023-08-28 08:50:20.167644
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '53c561be097d'
+down_revision = '0714589d508e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('embedding_tensor', sa.Column('sub_key', sa.Integer(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('embedding_tensor', 'sub_key')
+    # ### end Alembic commands ###
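
For orientation, a minimal sketch of the migrated table as a SQLAlchemy model; the class below, its name, and the extra columns are assumptions for illustration, not the repository's actual model definition:

# Minimal sketch, assuming a SQLAlchemy model mirroring the migrated table.
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class EmbeddingTensor(Base):  # hypothetical model name
    __tablename__ = "embedding_tensor"

    id = sa.Column(sa.Integer, primary_key=True)
    # new in this commit: position of the tensor inside an embedding list;
    # nullable so existing single-embedding rows stay valid (sub_key = NULL)
    sub_key = sa.Column(sa.Integer, nullable=True)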

controller/attribute/manager.py

Lines changed: 23 additions & 2 deletions
@@ -2,14 +2,20 @@
 from controller.tokenization.tokenization_service import (
     request_reupload_docbins,
 )
+import json
 from submodules.model.business_objects import (
     attribute,
     record,
     tokenization,
     general,
 )
 from submodules.model.models import Attribute
-from submodules.model.enums import AttributeState, DataTypes, RecordTokenizationScope
+from submodules.model.enums import (
+    AttributeState,
+    DataTypes,
+    RecordTokenizationScope,
+    AttributeVisibility,
+)
 from util import daemon, notification

 from controller.task_queue import manager as task_queue_manager

@@ -68,6 +74,9 @@ def create_user_attribute(project_id: str, name: str, data_type: str) -> Attribute
         relative_position = 1
     else:
         relative_position = prev_relative_position + 1
+    visibility = None  # default
+    if data_type == DataTypes.EMBEDDING_LIST.value:
+        visibility = AttributeVisibility.HIDE.value

     attribute_item: Attribute = attribute.create(
         project_id,

@@ -77,6 +86,7 @@ def create_user_attribute(project_id: str, name: str, data_type: str) -> Attribute
         is_primary_key=False,
         user_created=True,
         state=AttributeState.INITIAL.value,
+        visibility=visibility,
         with_commit=True,
     )
     notification.send_organization_update(

@@ -355,4 +365,15 @@ def calculate_user_attribute_sample_records(
     calculated_attributes = util.run_attribute_calculation_exec_env(
         attribute_id=attribute_id, project_id=project_id, doc_bin=doc_bin_samples
     )
-    return list(calculated_attributes.keys()), list(calculated_attributes.values())
+    values = None
+    if (
+        attribute.get(project_id, attribute_id).data_type
+        == DataTypes.EMBEDDING_LIST.value
+    ):
+        # values are JSON-serialized so they can be easily transferred to the frontend.
+        # Since the return type is a list of strings, without json.dumps a str(xxxx)
+        # would be called, which can't be reliably deserialized if special characters
+        # are in the string.
+        values = [json.dumps(v) for v in list(calculated_attributes.values())]
+    else:
+        values = list(calculated_attributes.values())
+    return list(calculated_attributes.keys()), values
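
The serialization branch above is worth unpacking: for embedding lists each sample value is presumably itself a Python list, and str() on a list produces repr-style quoting that is not valid JSON. A small standalone illustration (the sample value is invented):

import json

# a hypothetical embedding-list sample value containing special characters
value = ["it's", 'a "quoted" token']

print(str(value))         # ["it's", 'a "quoted" token'] -> repr quoting, not JSON
print(json.dumps(value))  # ["it's", "a \"quoted\" token"] -> valid JSON

json.loads(json.dumps(value))  # round-trips cleanly
# json.loads(str(value)) would raise json.JSONDecodeError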

controller/record/manager.py

Lines changed: 2 additions & 1 deletion
@@ -25,9 +25,10 @@ def get_records_by_similarity_search(
     embedding_id: str,
     record_id: str,
     att_filter: Optional[List[Dict[str, Any]]] = None,
+    record_sub_key: Optional[int] = None,
 ) -> ExtendedSearch:
     record_ids = neural_search_connector.request_most_similar_record_ids(
-        project_id, embedding_id, record_id, 100, att_filter
+        project_id, embedding_id, record_id, 100, att_filter, record_sub_key
     )
     if not len(record_ids):
         record_ids = [record_id]

controller/record/neural_search_connector.py

Lines changed: 3 additions & 0 deletions
@@ -12,8 +12,11 @@ def request_most_similar_record_ids(
     record_id: str,
     limit: int,
     att_filter: Optional[List[Dict[str, Any]]] = None,
+    record_sub_key: Optional[int] = None,
 ) -> List[str]:
     url = f"{BASE_URI}/most_similar?project_id={project_id}&embedding_id={embedding_id}&record_id={record_id}&limit={limit}"
+    if record_sub_key is not None:
+        url += f"&record_sub_key={record_sub_key}"

     result = service_requests.post_call_or_raise(url, att_filter)
     return result
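
To make the wire format concrete, a hedged sketch of the URL this builds when a sub key is supplied; BASE_URI and all IDs below are invented placeholders:

# Illustrative only: mirrors the connector's URL construction above.
BASE_URI = "http://neural-search:80"  # placeholder service address

url = (
    f"{BASE_URI}/most_similar"
    "?project_id=p-123&embedding_id=e-456&record_id=r-789&limit=100"
)
record_sub_key = 2
if record_sub_key is not None:
    url += f"&record_sub_key={record_sub_key}"

print(url)
# http://neural-search:80/most_similar?project_id=p-123&embedding_id=e-456&record_id=r-789&limit=100&record_sub_key=2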

controller/transfer/project_transfer_manager.py

Lines changed: 9 additions & 15 deletions
@@ -150,30 +150,20 @@ def import_file(
     send_progress_update_throttle(project_id, task_id, 0)
     project_item = project.get(project_id)
     if not project_item.name:
-        project_item.name = data.get(
-            "project_details_data",
-        ).get(
+        project_item.name = data.get("project_details_data",).get(
             "name",
         )
-    project_item.description = data.get(
-        "project_details_data",
-    ).get(
+    project_item.description = data.get("project_details_data",).get(
         "description",
     )
-    project_item.tokenizer = data.get(
-        "project_details_data",
-    ).get(
+    project_item.tokenizer = data.get("project_details_data",).get(
         "tokenizer",
     )
-    spacy_language = data.get(
-        "project_details_data",
-    ).get(
+    spacy_language = data.get("project_details_data",).get(
         "tokenizer",
     )[:2]
     project_item.tokenizer_blank = spacy_language
-    project_item.status = data.get(
-        "project_details_data",
-    ).get(
+    project_item.status = data.get("project_details_data",).get(
         "status",
     )
     old_project_id = data.get(

@@ -429,6 +419,9 @@ def __transform_embedding_by_name(embedding_name: str):
         data=embedding_tensor_item.get(
             "data",
         ),
+        sub_key=embedding_tensor_item.get(
+            "sub_key",
+        ),
     )

 def __replace_embedding_name(

@@ -1278,6 +1271,7 @@ def get_project_export_dump(
                 "embedding_id": str(embedding_tensor_item[0]),
                 "record_id": str(embedding_tensor_item[1]),
                 "data": embedding_tensor_item[2],
+                "sub_key": embedding_tensor_item[3],
             }
             for embedding_tensor_item in embedding_tensors
         ]
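
Putting the export and import sides together, a hedged sketch of one embedding-tensor entry in a project dump after this change; all values are invented and real dumps carry full UUIDs and tensors:

# Illustrative export entry only; field names follow the dump code above.
embedding_tensor_entry = {
    "embedding_id": "0f8b-placeholder",
    "record_id": "a3c1-placeholder",
    "data": [0.12, -0.98, 0.44],
    "sub_key": 0,  # presumably None for classic (non-list) embeddings
}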

graphql_api/query/record.py

Lines changed: 3 additions & 1 deletion
@@ -46,6 +46,7 @@ class RecordQuery(graphene.ObjectType):
         embedding_id=graphene.ID(required=True),
         record_id=graphene.ID(required=True),
         att_filter=graphene.JSONString(required=False),
+        record_sub_key=graphene.Int(required=False),  # only for embedding lists
     )

     tokenize_record = graphene.Field(

@@ -112,12 +113,13 @@ def resolve_search_records_by_similarity(
         embedding_id: str,
         record_id: str,
         att_filter: Optional[List[Dict[str, Any]]] = None,
+        record_sub_key: Optional[int] = None,
     ) -> ExtendedSearch:
         auth.check_demo_access(info)
         auth.check_project_access(info, project_id)
         user_id = auth.get_user_by_info(info).id
         return manager.get_records_by_similarity_search(
-            project_id, user_id, embedding_id, record_id, att_filter
+            project_id, user_id, embedding_id, record_id, att_filter, record_sub_key
         )

     def resolve_tokenize_record(self, info, record_id: str) -> TokenizedRecord:
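
From the API side, a hedged sketch of how a caller might exercise the new argument by invoking the manager directly; the leading positional parameters are inferred from the resolver call above, and all IDs are placeholders:

# Hypothetical direct call into controller/record/manager; illustrative only.
extended_search = manager.get_records_by_similarity_search(
    "project-id",    # project_id
    "user-id",       # user_id
    "embedding-id",  # embedding_id
    "record-id",     # record_id
    None,            # att_filter
    0,               # record_sub_key: anchor on the first embedding-list entry (assumption)
)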

submodules/s3

Submodule s3 updated 1 file
