1010 embedding ,
1111 record_label_association ,
1212 record ,
13+ project ,
14+ user ,
1315)
14- from submodules .model .enums import EmbeddingPlatform , LabelSource
16+ from submodules .model .cognition_objects import group_member
17+ from submodules .model .integration_objects .helper import (
18+ REFINERY_ATTRIBUTE_ACCESS_GROUPS ,
19+ REFINERY_ATTRIBUTE_ACCESS_USERS ,
20+ )
21+ from submodules .model .enums import EmbeddingPlatform , LabelSource , UserRoles
1522
1623from .similarity_threshold import SimilarityThreshold , NO_THRESHOLD_INDICATOR
24+ import traceback
1725
1826port = int (os .environ ["QDRANT_PORT" ])
1927qdrant_client = QdrantClient (host = "qdrant" , port = port , timeout = 60 )
@@ -48,9 +56,25 @@ def most_similar_by_embedding(
4856 att_filter : Optional [List [Dict [str , Any ]]] = None ,
4957 threshold : Optional [float ] = None ,
5058 include_scores : bool = False ,
59+ user_id : Optional [str ] = None ,
5160) -> List [str ]:
5261 if not is_filter_valid_for_embedding (project_id , embedding_id , att_filter ):
5362 return []
63+ if project .check_access_management_active (project_id ):
64+ if not user_id :
65+ return []
66+ requesting_user = user .get (user_id )
67+ if not requesting_user :
68+ return []
69+ if requesting_user .role != UserRoles .ENGINEER .value :
70+ check_access = True
71+ group_members = group_member .get_by_user_id (user_id )
72+ group_ids = [str (group_member .group_id ) for group_member in group_members ]
73+ else :
74+ check_access = False
75+ else :
76+ check_access = False
77+
5478 tmp_limit = limit
5579 has_sub_key = embedding .has_sub_key (project_id , embedding_id )
5680 if has_sub_key :
@@ -66,14 +90,20 @@ def most_similar_by_embedding(
6690 elif similarity_threshold == NO_THRESHOLD_INDICATOR :
6791 similarity_threshold = None
6892 try :
93+ _filter = __build_filter (att_filter )
94+ if check_access :
95+ _filter = __add_access_management_filter (_filter , group_ids , user_id )
96+
6997 search_result = qdrant_client .search (
7098 collection_name = embedding_id ,
7199 query_vector = query_vector ,
72- query_filter = __build_filter ( att_filter ) ,
100+ query_filter = _filter ,
73101 limit = tmp_limit ,
74102 score_threshold = similarity_threshold ,
75103 )
76- except Exception :
104+ except Exception as e :
105+ print (f"Error during search in Qdrant: { e } " , flush = True )
106+ print (traceback .format_exc (), flush = True )
77107 return []
78108
79109 if include_scores :
@@ -118,39 +148,61 @@ def __is_label_filter(key: str) -> bool:
118148 return parts [0 ] == LABELS_QDRANT
119149
120150
121- def __build_filter (att_filter : List [Dict [str , Any ]]) -> models .Filter :
122- if att_filter is None or len ( att_filter ) == 0 :
151+ def __build_filter (att_filter : List [Dict [str , Any ]]) -> Optional [ models .Filter ] :
152+ if not att_filter :
123153 return None
124- must = [__build_filter_item (filter_item ) for filter_item in att_filter ]
154+ must = [__build_filter_item (item ) for item in att_filter ]
125155 return models .Filter (must = must )
126156
127157
158+ def __add_access_management_filter (
159+ base_filter : Optional [models .Filter ], group_ids : List [str ], user_id : str
160+ ) -> models .Filter :
161+ access_conditions = [
162+ models .FieldCondition (
163+ key = REFINERY_ATTRIBUTE_ACCESS_GROUPS ,
164+ match = models .MatchAny (any = group_ids ),
165+ ),
166+ models .FieldCondition (
167+ key = REFINERY_ATTRIBUTE_ACCESS_USERS ,
168+ match = models .MatchValue (value = user_id ),
169+ ),
170+ ]
171+
172+ if base_filter is None :
173+ return models .Filter (should = access_conditions )
174+
175+ return models .Filter (
176+ must = base_filter .must or [],
177+ should = access_conditions ,
178+ )
179+
180+
128181def __build_filter_item (filter_item : Dict [str , Any ]) -> models .FieldCondition :
129- if isinstance (filter_item ["value" ], list ):
130- if filter_item .get ("type" ) == "between" :
131- return models .FieldCondition (
132- key = filter_item ["key" ],
133- range = models .Range (
134- gte = filter_item ["value" ][0 ],
135- lte = filter_item ["value" ][1 ],
136- ),
137- )
138- else :
139- should = [
140- models .FieldCondition (
141- key = filter_item ["key" ], match = models .MatchValue (value = value )
142- )
143- for value in filter_item ["value" ]
144- ]
145- return models .Filter (should = should )
146- else :
182+ key = filter_item ["key" ]
183+ value = filter_item ["value" ]
184+ typ = filter_item .get ("type" )
185+
186+ # BETWEEN
187+ if isinstance (value , list ) and typ == "between" :
188+ return models .FieldCondition (
189+ key = key ,
190+ range = models .Range (gte = value [0 ], lte = value [1 ]),
191+ )
192+
193+ # IN (...)
194+ if isinstance (value , list ):
147195 return models .FieldCondition (
148- key = filter_item ["key" ],
149- match = models .MatchValue (
150- value = filter_item ["value" ],
151- ),
196+ key = key ,
197+ match = models .MatchAny (any = value ),
152198 )
153199
200+ # = single value
201+ return models .FieldCondition (
202+ key = key ,
203+ match = models .MatchValue (value = value ),
204+ )
205+
154206
155207def recreate_collection (project_id : str , embedding_id : str ) -> int :
156208 embedding_item = embedding .get (project_id , embedding_id )
0 commit comments