Skip to content

Commit 60a0a2e

Browse files
JWittmeyer, andhreljaKern, and LennartSchmidtKern
authored
Cognition integration provider (#87)
* chore: update submodules * Adding groups for access management (#83) * filter * engineer access * fixes * model * model * chore: update submodules --------- Co-authored-by: andhreljaKern <[email protected]> * chore: update submodules * chore: update submodules * perf: add REFINERY_ATTRIBUTE_ACCESS constants * chore: update submodules * chore: add todo comment * perf: error tracing * chore: update submodules * group id str * group id str * chore: update submodules * Fix filter with filter * PR comment * chore: update submodules * Submodules update --------- Co-authored-by: andhreljaKern <[email protected]> Co-authored-by: Lennart Schmidt <[email protected]> Co-authored-by: LennartSchmidtKern <[email protected]>
1 parent e3fefef commit 60a0a2e

File tree

3 files changed

+88
-29
lines changed

3 files changed

+88
-29
lines changed

app.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
)
1010
from submodules.model import session
1111

12+
import traceback
13+
1214
app = FastAPI()
1315

1416

@@ -17,6 +19,9 @@ async def handle_db_session(request: Request, call_next):
1719
session_token = general.get_ctx_token()
1820
try:
1921
response = await call_next(request)
22+
except Exception:
23+
print(traceback.format_exc(), flush=True)
24+
response = None
2025
finally:
2126
general.remove_and_refresh_session(session_token)
2227

@@ -66,6 +71,7 @@ class MostSimilarByEmbeddingRequest(BaseModel):
6671
att_filter: Optional[List[Dict[str, Any]]] = None
6772
threshold: Optional[Union[float, int]] = None
6873
question: Optional[str] = None
74+
user_id: Optional[str] = None
6975

7076

7177
@app.post("/most_similar_by_embedding")
@@ -99,6 +105,7 @@ def most_similar_by_embedding(
99105
request.att_filter,
100106
request.threshold,
101107
include_scores,
108+
request.user_id,
102109
)
103110

104111
if request.question:

neural_search/util.py

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,18 @@
1010
embedding,
1111
record_label_association,
1212
record,
13+
project,
14+
user,
1315
)
14-
from submodules.model.enums import EmbeddingPlatform, LabelSource
16+
from submodules.model.cognition_objects import group_member
17+
from submodules.model.integration_objects.helper import (
18+
REFINERY_ATTRIBUTE_ACCESS_GROUPS,
19+
REFINERY_ATTRIBUTE_ACCESS_USERS,
20+
)
21+
from submodules.model.enums import EmbeddingPlatform, LabelSource, UserRoles
1522

1623
from .similarity_threshold import SimilarityThreshold, NO_THRESHOLD_INDICATOR
24+
import traceback
1725

1826
port = int(os.environ["QDRANT_PORT"])
1927
qdrant_client = QdrantClient(host="qdrant", port=port, timeout=60)
@@ -48,9 +56,25 @@ def most_similar_by_embedding(
4856
att_filter: Optional[List[Dict[str, Any]]] = None,
4957
threshold: Optional[float] = None,
5058
include_scores: bool = False,
59+
user_id: Optional[str] = None,
5160
) -> List[str]:
5261
if not is_filter_valid_for_embedding(project_id, embedding_id, att_filter):
5362
return []
63+
if project.check_access_management_active(project_id):
64+
if not user_id:
65+
return []
66+
requesting_user = user.get(user_id)
67+
if not requesting_user:
68+
return []
69+
if requesting_user.role != UserRoles.ENGINEER.value:
70+
check_access = True
71+
group_members = group_member.get_by_user_id(user_id)
72+
group_ids = [str(group_member.group_id) for group_member in group_members]
73+
else:
74+
check_access = False
75+
else:
76+
check_access = False
77+
5478
tmp_limit = limit
5579
has_sub_key = embedding.has_sub_key(project_id, embedding_id)
5680
if has_sub_key:
@@ -66,14 +90,20 @@ def most_similar_by_embedding(
6690
elif similarity_threshold == NO_THRESHOLD_INDICATOR:
6791
similarity_threshold = None
6892
try:
93+
_filter = __build_filter(att_filter)
94+
if check_access:
95+
_filter = __add_access_management_filter(_filter, group_ids, user_id)
96+
6997
search_result = qdrant_client.search(
7098
collection_name=embedding_id,
7199
query_vector=query_vector,
72-
query_filter=__build_filter(att_filter),
100+
query_filter=_filter,
73101
limit=tmp_limit,
74102
score_threshold=similarity_threshold,
75103
)
76-
except Exception:
104+
except Exception as e:
105+
print(f"Error during search in Qdrant: {e}", flush=True)
106+
print(traceback.format_exc(), flush=True)
77107
return []
78108

79109
if include_scores:
@@ -118,39 +148,61 @@ def __is_label_filter(key: str) -> bool:
118148
return parts[0] == LABELS_QDRANT
119149

120150

121-
def __build_filter(att_filter: List[Dict[str, Any]]) -> Optional[models.Filter]:
    """Combine attribute filter entries into a single Qdrant filter.

    Args:
        att_filter: Filter entries, each converted with
            ``__build_filter_item``. May be ``None`` or empty.

    Returns:
        ``None`` when there is nothing to filter on; otherwise a
        ``models.Filter`` whose ``must`` clause holds one condition per
        entry, so every entry has to match.
    """
    if att_filter:
        conditions = []
        for entry in att_filter:
            conditions.append(__build_filter_item(entry))
        return models.Filter(must=conditions)

    # None and the empty list both mean "no attribute filtering".
    return None
126156

127157

158+
def __add_access_management_filter(
    base_filter: Optional[models.Filter], group_ids: List[str], user_id: str
) -> models.Filter:
    """Restrict a search filter to points the requesting user may access.

    A point passes the access check when the payload field named by
    ``REFINERY_ATTRIBUTE_ACCESS_GROUPS`` matches any entry of ``group_ids``
    or the payload field named by ``REFINERY_ATTRIBUTE_ACCESS_USERS``
    matches ``user_id``.

    Args:
        base_filter: Filter built from the caller's attribute filter, or
            ``None`` when no attribute filter was given.
        group_ids: Ids of the groups the user belongs to.
        user_id: Id of the requesting user.

    Returns:
        A filter that requires the access check and, when ``base_filter``
        is given, additionally requires ``base_filter`` in its entirety.
    """
    access_filter = models.Filter(
        should=[
            models.FieldCondition(
                key=REFINERY_ATTRIBUTE_ACCESS_GROUPS,
                match=models.MatchAny(any=group_ids),
            ),
            models.FieldCondition(
                key=REFINERY_ATTRIBUTE_ACCESS_USERS,
                match=models.MatchValue(value=user_id),
            ),
        ]
    )

    if base_filter is None:
        return access_filter

    # Nest both filters as `must` clauses instead of copying only
    # `base_filter.must`: copying would silently drop any `should` /
    # `must_not` clauses of the base filter and could mix its `should`
    # conditions with the access conditions. For a must-only base filter
    # the result is logically the same as before.
    return models.Filter(must=[base_filter, access_filter])
179+
180+
128181
def __build_filter_item(filter_item: Dict[str, Any]) -> models.FieldCondition:
    """Translate one attribute filter entry into a Qdrant field condition.

    Args:
        filter_item: Mapping with the mandatory entries ``"key"`` and
            ``"value"`` and an optional ``"type"``.

    Returns:
        A ``models.FieldCondition`` on ``filter_item["key"]``:
        a range condition (first list element as ``gte``, second as
        ``lte``) when the value is a list and the type is ``"between"``;
        a match-any condition for any other list value; an exact-value
        match for a non-list value.

    Raises:
        KeyError: If ``"key"`` or ``"value"`` is missing.
        IndexError: If a ``"between"`` list has fewer than two elements.
    """
    field_name = filter_item["key"]
    operand = filter_item["value"]
    filter_type = filter_item.get("type")

    if not isinstance(operand, list):
        # Scalar operand: the field has to equal it exactly.
        condition = {"match": models.MatchValue(value=operand)}
    elif filter_type == "between":
        # Two-element list read as an inclusive [lower, upper] range.
        condition = {"range": models.Range(gte=operand[0], lte=operand[1])}
    else:
        # Plain list: the field may equal any of the listed values.
        condition = {"match": models.MatchAny(any=operand)}

    return models.FieldCondition(key=field_name, **condition)
205+
154206

155207
def recreate_collection(project_id: str, embedding_id: str) -> int:
156208
embedding_item = embedding.get(project_id, embedding_id)

0 commit comments

Comments
 (0)