Skip to content

Commit f178601

Browse files
Switch from security filter to built-in ACL enforcement (#2771)
* Use built-in ACL enforcement over security filter trimming
1 parent 256ebc9 commit f178601

File tree

41 files changed

+1072
-670
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1072
-670
lines changed

app/backend/app.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -437,7 +437,6 @@ async def setup_clients():
437437
AZURE_TENANT_ID = os.getenv("AZURE_TENANT_ID")
438438
AZURE_USE_AUTHENTICATION = os.getenv("AZURE_USE_AUTHENTICATION", "").lower() == "true"
439439
AZURE_ENFORCE_ACCESS_CONTROL = os.getenv("AZURE_ENFORCE_ACCESS_CONTROL", "").lower() == "true"
440-
AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS = os.getenv("AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS", "").lower() == "true"
441440
AZURE_ENABLE_UNAUTHENTICATED_ACCESS = os.getenv("AZURE_ENABLE_UNAUTHENTICATED_ACCESS", "").lower() == "true"
442441
AZURE_SERVER_APP_ID = os.getenv("AZURE_SERVER_APP_ID")
443442
AZURE_SERVER_APP_SECRET = os.getenv("AZURE_SERVER_APP_SECRET")
@@ -543,8 +542,7 @@ async def setup_clients():
543542
server_app_secret=AZURE_SERVER_APP_SECRET,
544543
client_app_id=AZURE_CLIENT_APP_ID,
545544
tenant_id=AZURE_AUTH_TENANT_ID,
546-
require_access_control=AZURE_ENFORCE_ACCESS_CONTROL,
547-
enable_global_documents=AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS,
545+
enforce_access_control=AZURE_ENFORCE_ACCESS_CONTROL,
548546
enable_unauthenticated_access=AZURE_ENABLE_UNAUTHENTICATED_ACCESS,
549547
)
550548

@@ -578,6 +576,8 @@ async def setup_clients():
578576
raise ValueError(
579577
"AZURE_USERSTORAGE_ACCOUNT and AZURE_USERSTORAGE_CONTAINER must be set when USE_USER_UPLOAD is true"
580578
)
579+
if not AZURE_ENFORCE_ACCESS_CONTROL:
580+
raise ValueError("AZURE_ENFORCE_ACCESS_CONTROL must be true when USE_USER_UPLOAD is true")
581581
user_blob_manager = AdlsBlobManager(
582582
endpoint=f"https://{AZURE_USERSTORAGE_ACCOUNT}.dfs.core.windows.net",
583583
container=AZURE_USERSTORAGE_CONTAINER,
@@ -676,7 +676,6 @@ async def setup_clients():
676676
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
677677
agent_client=agent_client,
678678
openai_client=openai_client,
679-
auth_helper=auth_helper,
680679
chatgpt_model=OPENAI_CHATGPT_MODEL,
681680
chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
682681
embedding_model=OPENAI_EMB_MODEL,
@@ -703,7 +702,6 @@ async def setup_clients():
703702
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
704703
agent_client=agent_client,
705704
openai_client=openai_client,
706-
auth_helper=auth_helper,
707705
chatgpt_model=OPENAI_CHATGPT_MODEL,
708706
chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
709707
embedding_model=OPENAI_EMB_MODEL,

app/backend/approaches/approach.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
)
3333

3434
from approaches.promptmanager import PromptManager
35-
from core.authentication import AuthenticationHelper
3635
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
3736
from prepdocslib.embeddings import ImageEmbeddings
3837

@@ -152,7 +151,6 @@ def __init__(
152151
self,
153152
search_client: SearchClient,
154153
openai_client: AsyncOpenAI,
155-
auth_helper: AuthenticationHelper,
156154
query_language: Optional[str],
157155
query_speller: Optional[str],
158156
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
@@ -169,7 +167,6 @@ def __init__(
169167
):
170168
self.search_client = search_client
171169
self.openai_client = openai_client
172-
self.auth_helper = auth_helper
173170
self.query_language = query_language
174171
self.query_speller = query_speller
175172
self.embedding_deployment = embedding_deployment
@@ -185,17 +182,14 @@ def __init__(
185182
self.global_blob_manager = global_blob_manager
186183
self.user_blob_manager = user_blob_manager
187184

188-
def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
185+
def build_filter(self, overrides: dict[str, Any]) -> Optional[str]:
189186
include_category = overrides.get("include_category")
190187
exclude_category = overrides.get("exclude_category")
191-
security_filter = self.auth_helper.build_security_filters(overrides, auth_claims)
192188
filters = []
193189
if include_category:
194190
filters.append("category eq '{}'".format(include_category.replace("'", "''")))
195191
if exclude_category:
196192
filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
197-
if security_filter:
198-
filters.append(security_filter)
199193
return None if len(filters) == 0 else " and ".join(filters)
200194

201195
async def search(
@@ -211,6 +205,7 @@ async def search(
211205
minimum_search_score: Optional[float] = None,
212206
minimum_reranker_score: Optional[float] = None,
213207
use_query_rewriting: Optional[bool] = None,
208+
access_token: Optional[str] = None,
214209
) -> list[Document]:
215210
search_text = query_text if use_text_search else ""
216211
search_vectors = vectors if use_vector_search else []
@@ -227,13 +222,15 @@ async def search(
227222
query_speller=self.query_speller,
228223
semantic_configuration_name="default",
229224
semantic_query=query_text,
225+
x_ms_query_source_authorization=access_token,
230226
)
231227
else:
232228
results = await self.search_client.search(
233229
search_text=search_text,
234230
filter=filter,
235231
top=top,
236232
vector_queries=search_vectors,
233+
x_ms_query_source_authorization=access_token,
237234
)
238235

239236
documents: list[Document] = []
@@ -275,6 +272,7 @@ async def run_agentic_retrieval(
275272
filter_add_on: Optional[str] = None,
276273
minimum_reranker_score: Optional[float] = None,
277274
results_merge_strategy: Optional[str] = None,
275+
access_token: Optional[str] = None,
278276
) -> tuple[KnowledgeAgentRetrievalResponse, list[Document]]:
279277
# STEP 1: Invoke agentic retrieval
280278
response = await agent_client.retrieve(
@@ -292,7 +290,8 @@ async def run_agentic_retrieval(
292290
filter_add_on=filter_add_on,
293291
)
294292
],
295-
)
293+
),
294+
x_ms_query_source_authorization=access_token,
296295
)
297296

298297
# Map activity id -> agent's internal search query

app/backend/approaches/chatreadretrieveread.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
ThoughtStep,
2121
)
2222
from approaches.promptmanager import PromptManager
23-
from core.authentication import AuthenticationHelper
2423
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
2524
from prepdocslib.embeddings import ImageEmbeddings
2625

@@ -42,7 +41,6 @@ def __init__(
4241
agent_model: Optional[str],
4342
agent_deployment: Optional[str],
4443
agent_client: KnowledgeAgentRetrievalClient,
45-
auth_helper: AuthenticationHelper,
4644
openai_client: AsyncOpenAI,
4745
chatgpt_model: str,
4846
chatgpt_deployment: Optional[str], # Not needed for non-Azure OpenAI
@@ -67,7 +65,6 @@ def __init__(
6765
self.agent_deployment = agent_deployment
6866
self.agent_client = agent_client
6967
self.openai_client = openai_client
70-
self.auth_helper = auth_helper
7168
self.chatgpt_model = chatgpt_model
7269
self.chatgpt_deployment = chatgpt_deployment
7370
self.embedding_deployment = embedding_deployment
@@ -279,7 +276,8 @@ async def run_search_approach(
279276
top = overrides.get("top", 3)
280277
minimum_search_score = overrides.get("minimum_search_score", 0.0)
281278
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
282-
search_index_filter = self.build_filter(overrides, auth_claims)
279+
search_index_filter = self.build_filter(overrides)
280+
access_token = auth_claims.get("access_token")
283281
send_text_sources = overrides.get("send_text_sources", True)
284282
send_image_sources = overrides.get("send_image_sources", self.multimodal_enabled) and self.multimodal_enabled
285283
search_text_embeddings = overrides.get("search_text_embeddings", True)
@@ -337,6 +335,7 @@ async def run_search_approach(
337335
minimum_search_score,
338336
minimum_reranker_score,
339337
use_query_rewriting,
338+
access_token,
340339
)
341340

342341
# STEP 3: Generate a contextual and content specific answer using the search results and chat history
@@ -388,7 +387,8 @@ async def run_agentic_retrieval_approach(
388387
overrides: dict[str, Any],
389388
auth_claims: dict[str, Any],
390389
):
391-
search_index_filter = self.build_filter(overrides, auth_claims)
390+
search_index_filter = self.build_filter(overrides)
391+
access_token = auth_claims.get("access_token")
392392
minimum_reranker_score = overrides.get("minimum_reranker_score", 0)
393393
top = overrides.get("top", 3)
394394
results_merge_strategy = overrides.get("results_merge_strategy", "interleaved")
@@ -403,6 +403,7 @@ async def run_agentic_retrieval_approach(
403403
filter_add_on=search_index_filter,
404404
minimum_reranker_score=minimum_reranker_score,
405405
results_merge_strategy=results_merge_strategy,
406+
access_token=access_token,
406407
)
407408

408409
data_points = await self.get_sources_content(

app/backend/approaches/retrievethenread.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
ThoughtStep,
1313
)
1414
from approaches.promptmanager import PromptManager
15-
from core.authentication import AuthenticationHelper
1615
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
1716
from prepdocslib.embeddings import ImageEmbeddings
1817

@@ -32,7 +31,6 @@ def __init__(
3231
agent_model: Optional[str],
3332
agent_deployment: Optional[str],
3433
agent_client: KnowledgeAgentRetrievalClient,
35-
auth_helper: AuthenticationHelper,
3634
openai_client: AsyncOpenAI,
3735
chatgpt_model: str,
3836
chatgpt_deployment: Optional[str], # Not needed for non-Azure OpenAI
@@ -58,7 +56,6 @@ def __init__(
5856
self.agent_client = agent_client
5957
self.chatgpt_deployment = chatgpt_deployment
6058
self.openai_client = openai_client
61-
self.auth_helper = auth_helper
6259
self.chatgpt_model = chatgpt_model
6360
self.embedding_model = embedding_model
6461
self.embedding_dimensions = embedding_dimensions
@@ -155,7 +152,8 @@ async def run_search_approach(
155152
top = overrides.get("top", 3)
156153
minimum_search_score = overrides.get("minimum_search_score", 0.0)
157154
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
158-
filter = self.build_filter(overrides, auth_claims)
155+
filter = self.build_filter(overrides)
156+
access_token = auth_claims.get("access_token")
159157
q = str(messages[-1]["content"])
160158
send_text_sources = overrides.get("send_text_sources", True)
161159
send_image_sources = overrides.get("send_image_sources", self.multimodal_enabled) and self.multimodal_enabled
@@ -183,6 +181,7 @@ async def run_search_approach(
183181
minimum_search_score,
184182
minimum_reranker_score,
185183
use_query_rewriting,
184+
access_token,
186185
)
187186

188187
data_points = await self.get_sources_content(
@@ -225,7 +224,8 @@ async def run_agentic_retrieval_approach(
225224
auth_claims: dict[str, Any],
226225
) -> ExtraInfo:
227226
minimum_reranker_score = overrides.get("minimum_reranker_score", 0)
228-
search_index_filter = self.build_filter(overrides, auth_claims)
227+
search_index_filter = self.build_filter(overrides)
228+
access_token = auth_claims.get("access_token")
229229
top = overrides.get("top", 3)
230230
results_merge_strategy = overrides.get("results_merge_strategy", "interleaved")
231231
send_text_sources = overrides.get("send_text_sources", True)
@@ -239,6 +239,7 @@ async def run_agentic_retrieval_approach(
239239
filter_add_on=search_index_filter,
240240
minimum_reranker_score=minimum_reranker_score,
241241
results_merge_strategy=results_merge_strategy,
242+
access_token=access_token,
242243
)
243244

244245
data_points = await self.get_sources_content(

0 commit comments

Comments
 (0)