Skip to content

Commit 2052b5b

Browse files
Qdrant filter integrations (#142)
* Filtered attributes from config request when creating an embedding
* Filter params for neural search
* Query for getting unique values by attribute id
* Submodules change
* Update embedding payloads mutation
* Endpoint for unique values changed
* Removed unused code
* Removed unused code
* PR comments
* Import export field for filter attributes
* Submodules updated
* PR comments
* PR comments
1 parent 15b7b94 commit 2052b5b

File tree

8 files changed

+86
-20
lines changed

8 files changed

+86
-20
lines changed

controller/embedding/connector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ def request_listing_recommended_encoders() -> Any:
1010
url = f"{BASE_URI}/classification/recommend/TEXT" # TODO does here have to be a data type?
1111
return service_requests.get_call_or_raise(url)
1212

13-
def request_embedding(
14-
project_id: str, embedding_id: str
15-
) -> Any:
13+
14+
def request_embedding(project_id: str, embedding_id: str) -> Any:
1615
url = f"{BASE_URI}/embed"
1716
data = {
1817
"project_id": str(project_id),

controller/embedding/manager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __recreate_embedding(project_id: str, embedding_id: str) -> Embedding:
185185
model=old_embedding_item.model,
186186
platform=old_embedding_item.platform,
187187
api_token=old_embedding_item.api_token,
188+
filter_attributes=old_embedding_item.filter_attributes,
188189
additional_data=old_embedding_item.additional_data,
189190
with_commit=False,
190191
)
@@ -212,3 +213,17 @@ def __recreate_embedding(project_id: str, embedding_id: str) -> Embedding:
212213
connector.request_deleting_embedding(project_id, old_id)
213214
daemon.run(connector.request_embedding, project_id, new_embedding_item.id)
214215
return new_embedding_item
216+
217+
218+
def update_embedding_payload(
219+
project_id: str, embedding_id: str, filter_attributes: List[str]
220+
) -> None:
221+
notification.send_organization_update(
222+
project_id=project_id,
223+
message=f"upload_embedding_payload:{str(embedding_id)}:start",
224+
)
225+
embedding.update_embedding_filter_attributes(
226+
project_id, embedding_id, filter_attributes, with_commit=True
227+
)
228+
connector.request_deleting_embedding(project_id, embedding_id)
229+
connector.request_tensor_upload(project_id, embedding_id)

controller/record/manager.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from typing import List, Dict, Any
1+
from typing import List, Dict, Any, Optional
22

33
from graphql_api.types import ExtendedSearch
44
from submodules.model import Record, Attribute
5-
from submodules.model.business_objects import general, record, user_session, embedding
5+
from submodules.model.business_objects import (
6+
record,
7+
user_session,
8+
embedding,
9+
attribute,
10+
)
611
from service.search import search
712

813
from controller.record import neural_search_connector
@@ -19,9 +24,10 @@ def get_records_by_similarity_search(
1924
user_id: str,
2025
embedding_id: str,
2126
record_id: str,
27+
att_filter: Optional[List[Dict[str, Any]]] = None,
2228
) -> ExtendedSearch:
2329
record_ids = neural_search_connector.request_most_similar_record_ids(
24-
project_id, embedding_id, record_id, 100
30+
project_id, embedding_id, record_id, 100, att_filter
2531
)
2632
if not len(record_ids):
2733
record_ids = [record_id]
@@ -102,3 +108,7 @@ def __reupload_embeddings(project_id: str) -> None:
102108
embeddings = embedding.get_finished_embeddings(project_id)
103109
for e in embeddings:
104110
embedding_manager.request_tensor_upload(project_id, str(e.id))
111+
112+
113+
def get_unique_values_by_attributes(project_id: str) -> Dict[str, List[str]]:
114+
return attribute.get_unique_values_by_attributes(project_id)
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import os
2-
from typing import List
2+
from typing import List, Optional, Dict, Any
33

44
from util import service_requests
55

66
BASE_URI = os.getenv("NEURAL_SEARCH")
77

88

99
def request_most_similar_record_ids(
10-
project_id: str, embedding_id: str, record_id: str, limit: int
10+
project_id: str,
11+
embedding_id: str,
12+
record_id: str,
13+
limit: int,
14+
att_filter: Optional[List[Dict[str, Any]]] = None,
1115
) -> List[str]:
1216
url = f"{BASE_URI}/most_similar?project_id={project_id}&embedding_id={embedding_id}&record_id={record_id}&limit={limit}"
1317

14-
# changed from get to post so we can send the filter -> however currently filter isn't part of the prototype so None
15-
result = service_requests.post_call_or_raise(url, None)
18+
result = service_requests.post_call_or_raise(url, att_filter)
1619
return result

controller/transfer/project_transfer_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def __transform_embedding_by_name(embedding_name: str):
397397
model=embedding_item.get(
398398
"model",
399399
),
400+
filter_attributes=embedding_item.get("filter_attributes"),
400401
additional_data=embedding_item.get(
401402
"additional_data",
402403
),
@@ -929,7 +930,6 @@ def __post_processing_import_threaded(
929930
f"Tokenization finished, continue with embedding handling of project {project_id}"
930931
)
931932
break
932-
933933
if not data.get(
934934
"embedding_tensors_data",
935935
):
@@ -1177,6 +1177,7 @@ def get_project_export_dump(
11771177
"finished_at": embedding_item.finished_at,
11781178
"platform": embedding_item.platform,
11791179
"model": embedding_item.model,
1180+
"filter_attributes": embedding_item.filter_attributes,
11801181
"additional_data": embedding_item.additional_data,
11811182
}
11821183
for embedding_item in embeddings

graphql_api/mutation/embedding.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, List, Optional
22
from controller.auth import manager as auth
33
from controller.embedding import manager
44
from controller.auth.manager import get_user_by_info
@@ -31,6 +31,8 @@ def mutate(self, info, project_id: str, attribute_id: str, config: Dict[str, Any
3131
api_token = config.get("apiToken")
3232
terms_text = config.get("termsText")
3333
terms_accepted = config.get("termsAccepted")
34+
filter_attributes = config.get("filterAttributes")
35+
3436
additional_data = None
3537
if config.get("base") is not None:
3638
additional_data = {
@@ -39,10 +41,6 @@ def mutate(self, info, project_id: str, attribute_id: str, config: Dict[str, Any
3941
"version": config.get("version"),
4042
}
4143

42-
# prototyping logic, this will be part of config after ui integration
43-
relevant_attribute_list = attribute_do.get_all_possible_names_for_qdrant(
44-
project_id
45-
)
4644
task_queue_manager.add_task(
4745
project_id,
4846
TaskType.EMBEDDING,
@@ -58,7 +56,7 @@ def mutate(self, info, project_id: str, attribute_id: str, config: Dict[str, Any
5856
"api_token": api_token,
5957
"terms_text": terms_text,
6058
"terms_accepted": terms_accepted,
61-
"filter_attributes": relevant_attribute_list,
59+
"filter_attributes": filter_attributes,
6260
"additional_data": additional_data,
6361
},
6462
)
@@ -85,6 +83,31 @@ def mutate(self, info, project_id: str, embedding_id: str):
8583
return DeleteEmbedding(ok=True)
8684

8785

86+
class UpdateEmbeddingPayload(graphene.Mutation):
87+
class Arguments:
88+
project_id = graphene.ID(required=True)
89+
embedding_id = graphene.ID(required=True)
90+
filter_attributes = graphene.JSONString(required=False)
91+
92+
ok = graphene.Boolean()
93+
94+
def mutate(
95+
self,
96+
info,
97+
project_id: str,
98+
embedding_id: str,
99+
filter_attributes: Optional[List[str]] = None,
100+
):
101+
auth.check_demo_access(info)
102+
auth.check_project_access(info, project_id)
103+
manager.update_embedding_payload(project_id, embedding_id, filter_attributes)
104+
notification.send_organization_update(
105+
project_id, f"embedding_updated:{embedding_id}"
106+
)
107+
return UpdateEmbeddingPayload(ok=True)
108+
109+
88110
class EmbeddingMutation(graphene.ObjectType):
89111
create_embedding = CreateEmbedding.Field()
90112
delete_embedding = DeleteEmbedding.Field()
113+
update_embedding_payload = UpdateEmbeddingPayload.Field()

graphql_api/query/record.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class RecordQuery(graphene.ObjectType):
4545
project_id=graphene.ID(required=True),
4646
embedding_id=graphene.ID(required=True),
4747
record_id=graphene.ID(required=True),
48+
att_filter=graphene.JSONString(required=False),
4849
)
4950

5051
tokenize_record = graphene.Field(
@@ -59,6 +60,10 @@ class RecordQuery(graphene.ObjectType):
5960
record_ids=graphene.List(graphene.ID, required=True),
6061
)
6162

63+
unique_values_by_attributes = graphene.Field(
64+
graphene.JSONString, project_id=graphene.ID(required=True)
65+
)
66+
6267
def resolve_all_records(self, info, project_id: str) -> List[Record]:
6368
auth.check_project_access(info, project_id)
6469
return manager.get_all_records(project_id)
@@ -101,13 +106,18 @@ def resolve_search_records_extended(
101106
)
102107

103108
def resolve_search_records_by_similarity(
104-
self, info, project_id: str, embedding_id: str, record_id: str
109+
self,
110+
info,
111+
project_id: str,
112+
embedding_id: str,
113+
record_id: str,
114+
att_filter: Optional[List[Dict[str, Any]]] = None,
105115
) -> ExtendedSearch:
106116
auth.check_demo_access(info)
107117
auth.check_project_access(info, project_id)
108118
user_id = auth.get_user_by_info(info).id
109119
return manager.get_records_by_similarity_search(
110-
project_id, user_id, embedding_id, record_id
120+
project_id, user_id, embedding_id, record_id, att_filter
111121
)
112122

113123
def resolve_tokenize_record(self, info, record_id: str) -> TokenizedRecord:
@@ -127,3 +137,8 @@ def resolve_record_comments(
127137
auth.check_project_access(info, project_id)
128138
user_id = auth.get_user_id_by_info(info)
129139
return comment_manager.get_record_comments(project_id, user_id, record_ids)
140+
141+
def resolve_unique_values_by_attributes(self, info, project_id: str) -> str:
142+
auth.check_demo_access(info)
143+
auth.check_project_access(info, project_id)
144+
return manager.get_unique_values_by_attributes(project_id)

0 commit comments

Comments
 (0)