diff --git a/app/backend/app.py b/app/backend/app.py index 62707e0cd7..d391e8b779 100644 --- a/app/backend/app.py +++ b/app/backend/app.py @@ -471,6 +471,7 @@ async def setup_clients(): USE_CHAT_HISTORY_BROWSER = os.getenv("USE_CHAT_HISTORY_BROWSER", "").lower() == "true" USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true" USE_AGENTIC_RETRIEVAL = os.getenv("USE_AGENTIC_RETRIEVAL", "").lower() == "true" + ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA = os.getenv("ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA", "").lower() == "true" # WEBSITE_HOSTNAME is always set by App Service, RUNNING_IN_PRODUCTION is set in main.bicep RUNNING_ON_AZURE = os.getenv("WEBSITE_HOSTNAME") is not None or os.getenv("RUNNING_IN_PRODUCTION") is not None @@ -689,6 +690,7 @@ async def setup_clients(): query_speller=AZURE_SEARCH_QUERY_SPELLER, prompt_manager=prompt_manager, reasoning_effort=OPENAI_REASONING_EFFORT, + hydrate_references=ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA, multimodal_enabled=USE_MULTIMODAL, image_embeddings_client=image_embeddings_client, global_blob_manager=global_blob_manager, @@ -716,6 +718,7 @@ async def setup_clients(): query_speller=AZURE_SEARCH_QUERY_SPELLER, prompt_manager=prompt_manager, reasoning_effort=OPENAI_REASONING_EFFORT, + hydrate_references=ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA, multimodal_enabled=USE_MULTIMODAL, image_embeddings_client=image_embeddings_client, global_blob_manager=global_blob_manager, diff --git a/app/backend/approaches/approach.py b/app/backend/approaches/approach.py index 22dd3043f5..1e84417144 100644 --- a/app/backend/approaches/approach.py +++ b/app/backend/approaches/approach.py @@ -162,6 +162,7 @@ def __init__( openai_host: str, prompt_manager: PromptManager, reasoning_effort: Optional[str] = None, + hydrate_references: bool = False, multimodal_enabled: bool = False, image_embeddings_client: Optional[ImageEmbeddings] = None, global_blob_manager: Optional[BlobManager] = None, @@ -179,6 +180,7 @@ def __init__( 
self.openai_host = openai_host self.prompt_manager = prompt_manager self.reasoning_effort = reasoning_effort + self.hydrate_references = hydrate_references self.include_token_usage = True self.multimodal_enabled = multimodal_enabled self.image_embeddings_client = image_embeddings_client @@ -236,7 +238,7 @@ async def search( vector_queries=search_vectors, ) - documents = [] + documents: list[Document] = [] async for page in results.by_page(): async for document in page: documents.append( @@ -299,40 +301,112 @@ async def run_agentic_retrieval( ) ) - # STEP 2: Generate a contextual and content specific answer using the search results and chat history + # Map activity id -> agent's internal search query activities = response.activity - activity_mapping = ( + activity_mapping: dict[int, str] = ( { - activity.id: activity.query.search if activity.query else "" + activity.id: activity.query.search for activity in activities - if isinstance(activity, KnowledgeAgentSearchActivityRecord) + if ( + isinstance(activity, KnowledgeAgentSearchActivityRecord) + and activity.query + and activity.query.search is not None + ) } if activities else {} ) - results = [] - if response and response.references: - if results_merge_strategy == "interleaved": - # Use interleaved reference order - references = sorted(response.references, key=lambda reference: int(reference.id)) - else: - # Default to descending strategy - references = response.references - for reference in references: - if isinstance(reference, KnowledgeAgentAzureSearchDocReference) and reference.source_data: - results.append( + # No refs? 
we're done + if not (response and response.references): + return response, [] + + # Extract references + refs = [r for r in response.references if isinstance(r, KnowledgeAgentAzureSearchDocReference)] + + documents: list[Document] = [] + + if self.hydrate_references: + # Hydrate references to get full documents + documents = await self.hydrate_agent_references( + references=refs, + top=top, + ) + else: + # Create documents from reference source data + for ref in refs: + if ref.source_data: + documents.append( Document( - id=reference.doc_key, - content=reference.source_data["content"], - sourcepage=reference.source_data["sourcepage"], - search_agent_query=activity_mapping[reference.activity_source], + id=ref.doc_key, + content=ref.source_data.get("content"), + sourcepage=ref.source_data.get("sourcepage"), ) ) - if top and len(results) == top: - break + if top and len(documents) >= top: + break + + # Build mappings for agent queries and sorting + ref_to_activity: dict[str, int] = {} + doc_to_ref_id: dict[str, str] = {} + for ref in refs: + if ref.doc_key: + ref_to_activity[ref.doc_key] = ref.activity_source + doc_to_ref_id[ref.doc_key] = ref.id + + # Inject agent search queries into all documents + for doc in documents: + if doc.id and doc.id in ref_to_activity: + activity_id = ref_to_activity[doc.id] + doc.search_agent_query = activity_mapping.get(activity_id, "") + + # Apply sorting strategy to the documents + if results_merge_strategy == "interleaved": # Use interleaved reference order + documents = sorted( + documents, + key=lambda d: int(doc_to_ref_id.get(d.id, 0)) if d.id and doc_to_ref_id.get(d.id) else 0, + ) + # else: Default - preserve original order + + return response, documents + + async def hydrate_agent_references( + self, + references: list[KnowledgeAgentAzureSearchDocReference], + top: Optional[int], + ) -> list[Document]: + doc_keys: set[str] = set() + + for ref in references: + if not ref.doc_key: + continue + doc_keys.add(ref.doc_key) + if top 
and len(doc_keys) >= top: + break + + if not doc_keys: + return [] + + # Build search filter only on unique doc IDs + id_csv = ",".join(doc_keys) + id_filter = f"search.in(id, '{id_csv}', ',')" + + # Fetch full documents + hydrated_docs: list[Document] = await self.search( + top=len(doc_keys), + query_text=None, + filter=id_filter, + vectors=[], + use_text_search=False, + use_vector_search=False, + use_semantic_ranker=False, + use_semantic_captions=False, + minimum_search_score=None, + minimum_reranker_score=None, + use_query_rewriting=False, + ) - return response, results + return hydrated_docs async def get_sources_content( self, diff --git a/app/backend/approaches/chatreadretrieveread.py b/app/backend/approaches/chatreadretrieveread.py index aa725e5333..604f7388df 100644 --- a/app/backend/approaches/chatreadretrieveread.py +++ b/app/backend/approaches/chatreadretrieveread.py @@ -57,6 +57,7 @@ def __init__( query_speller: str, prompt_manager: PromptManager, reasoning_effort: Optional[str] = None, + hydrate_references: bool = False, multimodal_enabled: bool = False, image_embeddings_client: Optional[ImageEmbeddings] = None, global_blob_manager: Optional[BlobManager] = None, @@ -84,6 +85,7 @@ def __init__( self.query_rewrite_tools = self.prompt_manager.load_tools("chat_query_rewrite_tools.json") self.answer_prompt = self.prompt_manager.load_prompt("chat_answer_question.prompty") self.reasoning_effort = reasoning_effort + self.hydrate_references = hydrate_references self.include_token_usage = True self.multimodal_enabled = multimodal_enabled self.image_embeddings_client = image_embeddings_client diff --git a/app/backend/approaches/retrievethenread.py b/app/backend/approaches/retrievethenread.py index 334065c992..0fb9834c8d 100644 --- a/app/backend/approaches/retrievethenread.py +++ b/app/backend/approaches/retrievethenread.py @@ -46,6 +46,7 @@ def __init__( query_speller: str, prompt_manager: PromptManager, reasoning_effort: Optional[str] = None, + 
hydrate_references: bool = False, multimodal_enabled: bool = False, image_embeddings_client: Optional[ImageEmbeddings] = None, global_blob_manager: Optional[BlobManager] = None, @@ -73,6 +74,7 @@ def __init__( self.answer_prompt = self.prompt_manager.load_prompt("ask_answer_question.prompty") self.reasoning_effort = reasoning_effort self.include_token_usage = True + self.hydrate_references = hydrate_references self.multimodal_enabled = multimodal_enabled self.image_embeddings_client = image_embeddings_client self.global_blob_manager = global_blob_manager diff --git a/docs/agentic_retrieval.md b/docs/agentic_retrieval.md index 0f8f8e69a9..baa55994c1 100644 --- a/docs/agentic_retrieval.md +++ b/docs/agentic_retrieval.md @@ -34,21 +34,33 @@ See the agentic retrieval documentation. azd env set AZURE_OPENAI_SEARCHAGENT_MODEL_VERSION 2025-04-14 ``` -3. **Update the infrastructure and application:** +3. **(Optional) Enable extra field hydration** + + By default, agentic retrieval only returns fields included in the semantic configuration. + + You can enable this optional feature below, to include all fields from the search index in the result. + ⚠️ This feature is currently only compatible with indexes set up with integrated vectorization, + or indexes that otherwise have an "id" field marked as filterable. + + ```shell + azd env set ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA true + ``` + +4. **Update the infrastructure and application:** Execute `azd up` to provision the infrastructure changes (only the new model, if you ran `up` previously) and deploy the application code with the updated environment variables. -4. **Try out the feature:** +5. **Try out the feature:** Open the web app and start a new chat. Agentic retrieval will be used to find all sources. -5. **Experiment with max subqueries:** +6. **Experiment with max subqueries:** Select the developer options in the web app and change max subqueries to any value between 1 and 20. 
This controls the maximum amount of subqueries that can be created in the query plan. ![Max subqueries screenshot](./images/max-subqueries.png) -6. **Review the query plan** +7. **Review the query plan** Agentic retrieval use additional billed tokens behind the scenes for the planning process. To see the token usage, select the lightbulb icon on a chat answer. This will open the "Thought process" tab, which shows the amount of tokens used by and the queries produced by the planning process diff --git a/infra/main.bicep b/infra/main.bicep index 33e28cc8ea..6964b9dc75 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -41,6 +41,7 @@ param storageSkuName string // Set in main.parameters.json param defaultReasoningEffort string // Set in main.parameters.json param useAgenticRetrieval bool // Set in main.parameters.json +param enableAgenticRetrievalSourceData bool // Set in main.parameters.json param userStorageAccountName string = '' param userStorageContainerName string = 'user-content' @@ -423,6 +424,7 @@ var appEnvVariables = { USE_SPEECH_OUTPUT_BROWSER: useSpeechOutputBrowser USE_SPEECH_OUTPUT_AZURE: useSpeechOutputAzure USE_AGENTIC_RETRIEVAL: useAgenticRetrieval + ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA: enableAgenticRetrievalSourceData // Chat history settings USE_CHAT_HISTORY_BROWSER: useChatHistoryBrowser USE_CHAT_HISTORY_COSMOS: useChatHistoryCosmos diff --git a/infra/main.parameters.json b/infra/main.parameters.json index 8a0f762196..dd047dc56f 100644 --- a/infra/main.parameters.json +++ b/infra/main.parameters.json @@ -104,79 +104,79 @@ "backendServiceName": { "value": "${AZURE_APP_SERVICE}" }, - "chatGptModelName":{ + "chatGptModelName": { "value": "${AZURE_OPENAI_CHATGPT_MODEL}" }, "chatGptDeploymentName": { "value": "${AZURE_OPENAI_CHATGPT_DEPLOYMENT}" }, - "chatGptDeploymentVersion":{ + "chatGptDeploymentVersion": { "value": "${AZURE_OPENAI_CHATGPT_DEPLOYMENT_VERSION}" }, - "chatGptDeploymentSkuName":{ + "chatGptDeploymentSkuName": { "value": 
"${AZURE_OPENAI_CHATGPT_DEPLOYMENT_SKU}" }, - "chatGptDeploymentCapacity":{ + "chatGptDeploymentCapacity": { "value": "${AZURE_OPENAI_CHATGPT_DEPLOYMENT_CAPACITY}" }, - "embeddingModelName":{ + "embeddingModelName": { "value": "${AZURE_OPENAI_EMB_MODEL_NAME}" }, "embeddingDeploymentName": { "value": "${AZURE_OPENAI_EMB_DEPLOYMENT}" }, - "embeddingDeploymentVersion":{ + "embeddingDeploymentVersion": { "value": "${AZURE_OPENAI_EMB_DEPLOYMENT_VERSION}" }, - "embeddingDeploymentSkuName":{ + "embeddingDeploymentSkuName": { "value": "${AZURE_OPENAI_EMB_DEPLOYMENT_SKU}" }, - "embeddingDeploymentCapacity":{ + "embeddingDeploymentCapacity": { "value": "${AZURE_OPENAI_EMB_DEPLOYMENT_CAPACITY}" }, "embeddingDimensions": { "value": "${AZURE_OPENAI_EMB_DIMENSIONS}" }, - "evalModelName":{ + "evalModelName": { "value": "${AZURE_OPENAI_EVAL_MODEL}" }, - "evalModelVersion":{ + "evalModelVersion": { "value": "${AZURE_OPENAI_EVAL_MODEL_VERSION}" }, "evalDeploymentName": { "value": "${AZURE_OPENAI_EVAL_DEPLOYMENT}" }, - "evalDeploymentSkuName":{ + "evalDeploymentSkuName": { "value": "${AZURE_OPENAI_EVAL_DEPLOYMENT_SKU}" }, - "evalDeploymentCapacity":{ + "evalDeploymentCapacity": { "value": "${AZURE_OPENAI_EVAL_DEPLOYMENT_CAPACITY}" }, - "searchAgentModelName":{ + "searchAgentModelName": { "value": "${AZURE_OPENAI_SEARCHAGENT_MODEL}" }, - "searchAgentModelVersion":{ + "searchAgentModelVersion": { "value": "${AZURE_OPENAI_SEARCHAGENT_MODEL_VERSION}" }, "searchAgentDeploymentName": { "value": "${AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT}" }, - "searchAgentDeploymentSkuName":{ + "searchAgentDeploymentSkuName": { "value": "${AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT_SKU}" }, - "searchAgentDeploymentCapacity":{ + "searchAgentDeploymentCapacity": { "value": "${AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT_CAPACITY}" }, "openAiHost": { "value": "${OPENAI_HOST=azure}" }, - "azureOpenAiCustomUrl":{ + "azureOpenAiCustomUrl": { "value": "${AZURE_OPENAI_CUSTOM_URL}" }, - "azureOpenAiApiVersion":{ + 
"azureOpenAiApiVersion": { "value": "${AZURE_OPENAI_API_VERSION}" }, - "azureOpenAiApiKey":{ + "azureOpenAiApiKey": { "value": "${AZURE_OPENAI_API_KEY_OVERRIDE}" }, "azureOpenAiDisableKeys": { @@ -324,7 +324,7 @@ "value": "${DEPLOYMENT_TARGET=containerapps}" }, "webAppExists": { - "value": "${SERVICE_WEB_RESOURCE_EXISTS=false}" + "value": "${SERVICE_WEB_RESOURCE_EXISTS=false}" }, "azureContainerAppsWorkloadProfile": { "value": "${AZURE_CONTAINER_APPS_WORKLOAD_PROFILE=Consumption}" @@ -338,6 +338,9 @@ "useAgenticRetrieval": { "value": "${USE_AGENTIC_RETRIEVAL=false}" }, + "enableAgenticRetrievalSourceData": { + "value": "${ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA=false}" + }, "ragSearchTextEmbeddings": { "value": "${RAG_SEARCH_TEXT_EMBEDDINGS=true}" }, diff --git a/tests/conftest.py b/tests/conftest.py index 4f5e4aed5b..89c1c66711 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ import msal import pytest import pytest_asyncio +from azure.core.credentials import AzureKeyCredential from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient from azure.search.documents.aio import SearchClient from azure.search.documents.indexes.aio import SearchIndexClient @@ -46,6 +47,10 @@ MockResponse, MockTransport, mock_retrieval_response, + mock_retrieval_response_with_duplicates, + mock_retrieval_response_with_missing_doc_key, + mock_retrieval_response_with_sorting, + mock_retrieval_response_with_top_limit, mock_speak_text_cancelled, mock_speak_text_failed, mock_speak_text_success, @@ -67,13 +72,37 @@ async def mock_search(self, *args, **kwargs): return MockAsyncSearchResultsIterator(kwargs.get("search_text"), kwargs.get("vector_queries")) -async def mock_retrieve(self, *args, **kwargs): - retrieval_request = kwargs.get("retrieval_request") - assert retrieval_request is not None - assert retrieval_request.target_index_params is not None - assert len(retrieval_request.target_index_params) == 1 - self.filter = 
retrieval_request.target_index_params[0].filter_add_on - return mock_retrieval_response() +def create_mock_retrieve(response_type="default"): + """Create a mock_retrieve function that returns different response types. + + Args: + response_type: Type of response to return. Options: + - "default": mock_retrieval_response() + - "sorting": mock_retrieval_response_with_sorting() + - "duplicates": mock_retrieval_response_with_duplicates() + - "missing_doc_key": mock_retrieval_response_with_missing_doc_key() + - "top_limit": mock_retrieval_response_with_top_limit() + """ + + async def mock_retrieve_parameterized(self, *args, **kwargs): + retrieval_request = kwargs.get("retrieval_request") + assert retrieval_request is not None + assert retrieval_request.target_index_params is not None + assert len(retrieval_request.target_index_params) == 1 + self.filter = retrieval_request.target_index_params[0].filter_add_on + + if response_type == "sorting": + return mock_retrieval_response_with_sorting() + elif response_type == "duplicates": + return mock_retrieval_response_with_duplicates() + elif response_type == "missing_doc_key": + return mock_retrieval_response_with_missing_doc_key() + elif response_type == "top_limit": + return mock_retrieval_response_with_top_limit() + else: # default + return mock_retrieval_response() + + return mock_retrieve_parameterized @pytest.fixture @@ -280,7 +309,7 @@ async def mock_get_index(*args, **kwargs): @pytest.fixture def mock_acs_agent(monkeypatch): - monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieve) + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve()) async def mock_get_agent(*args, **kwargs): return MockAgent @@ -418,6 +447,7 @@ async def mock_exists(*args, **kwargs): "AZURE_OPENAI_SEARCHAGENT_MODEL": "gpt-4.1-mini", "AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT": "gpt-4.1-mini", "USE_AGENTIC_RETRIEVAL": "true", + "ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA": "true", } ] @@ -431,6 +461,7 @@ 
async def mock_exists(*args, **kwargs): "AZURE_OPENAI_SEARCHAGENT_MODEL": "gpt-4.1-mini", "AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT": "gpt-4.1-mini", "USE_AGENTIC_RETRIEVAL": "true", + "ENABLE_AGENTIC_RETRIEVAL_SOURCE_DATA": "true", "AZURE_USE_AUTHENTICATION": "true", "AZURE_SERVER_APP_ID": "SERVER_APP", "AZURE_SERVER_APP_SECRET": "SECRET", @@ -1100,7 +1131,41 @@ def mock_user_directory_client(monkeypatch): @pytest.fixture def chat_approach(): return ChatReadRetrieveReadApproach( - search_client=None, + search_client=SearchClient(endpoint="", index_name="", credential=AzureKeyCredential("")), + search_index_name=None, + agent_model=None, + agent_deployment=None, + agent_client=None, + auth_helper=None, + openai_client=None, + chatgpt_model="gpt-4.1-mini", + chatgpt_deployment="chat", + embedding_deployment="embeddings", + embedding_model=MOCK_EMBEDDING_MODEL_NAME, + embedding_dimensions=MOCK_EMBEDDING_DIMENSIONS, + embedding_field="embedding3", + sourcepage_field="", + content_field="", + query_language="en-us", + query_speller="lexicon", + prompt_manager=PromptyManager(), + user_blob_manager=AdlsBlobManager( + endpoint="https://test-userstorage-account.dfs.core.windows.net", + container="test-userstorage-container", + credential=MockAzureCredential(), + ), + global_blob_manager=BlobManager( # on normal Azure storage + endpoint="https://test-globalstorage-account.blob.core.windows.net", + container="test-globalstorage-container", + credential=MockAzureCredential(), + ), + ) + + +@pytest.fixture +def chat_approach_with_hydration(): + return ChatReadRetrieveReadApproach( + search_client=SearchClient(endpoint="", index_name="", credential=AzureKeyCredential("")), search_index_name=None, agent_model=None, agent_deployment=None, @@ -1118,6 +1183,7 @@ def chat_approach(): query_language="en-us", query_speller="lexicon", prompt_manager=PromptyManager(), + hydrate_references=True, user_blob_manager=AdlsBlobManager( 
endpoint="https://test-userstorage-account.dfs.core.windows.net", container="test-userstorage-container", diff --git a/tests/mocks.py b/tests/mocks.py index 506bf75333..de84fa470e 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -268,6 +268,69 @@ def __init__(self, search_text, vector_queries: Optional[list[VectorQuery]]): }, ] ] + elif search_text == "hydrated": + self.data = [ + [ + { + "sourcepage": "Benefit_Options-2.pdf", + "sourcefile": "Benefit_Options.pdf", + "content": "There is a whistleblower policy.", + "embedding": [], + "category": "benefits", + "id": "Benefit_Options-2.pdf", + "@search.score": 0.03279569745063782, + "@search.reranker_score": 3.4577205181121826, + "@search.highlights": None, + "@search.captions": [MockCaption("Caption: A whistleblower policy.")], + }, + ] + ] + elif search_text == "hydrated_multi": + self.data = [ + [ + { + "id": "doc1", + "content": "Hydrated content 1", + "sourcepage": "page1.pdf", + "sourcefile": "file1.pdf", + "category": "category1", + "@search.score": 0.9, + "@search.reranker_score": 3.5, + "@search.highlights": None, + "@search.captions": [], + }, + { + "id": "doc2", + "content": "Hydrated content 2", + "sourcepage": "page2.pdf", + "sourcefile": "file2.pdf", + "category": "category2", + "@search.score": 0.8, + "@search.reranker_score": 3.2, + "@search.highlights": None, + "@search.captions": [], + }, + ] + ] + elif search_text == "hydrated_single": + self.data = [ + [ + { + "id": "doc1", + "content": "Hydrated content 1", + "sourcepage": "page1.pdf", + "sourcefile": "file1.pdf", + "category": "category1", + "@search.score": 0.9, + "@search.reranker_score": 3.5, + "@search.highlights": None, + "@search.captions": [], + }, + ] + ] + elif search_text == "hydrated_empty": + # Mock search results for empty hydration + self.data = [[]] else: self.data = [ [ @@ -392,6 +455,170 @@ def mock_retrieval_response(): ) +def mock_retrieval_response_with_sorting(): + """Mock response with multiple references for testing 
sorting""" + return KnowledgeAgentRetrievalResponse( + response=[ + KnowledgeAgentMessage( + role="assistant", + content=[KnowledgeAgentMessageTextContent(text="Test response")], + ) + ], + activity=[ + KnowledgeAgentSearchActivityRecord( + id=1, + target_index="index", + query=KnowledgeAgentSearchActivityRecordQuery(search="first query"), + count=10, + elapsed_ms=50, + ), + KnowledgeAgentSearchActivityRecord( + id=2, + target_index="index", + query=KnowledgeAgentSearchActivityRecordQuery(search="second query"), + count=10, + elapsed_ms=50, + ), + ], + references=[ + KnowledgeAgentAzureSearchDocReference( + id="2", # Higher ID for testing interleaved sorting + activity_source=2, + doc_key="doc2", + source_data={"content": "Content 2", "sourcepage": "page2.pdf"}, + ), + KnowledgeAgentAzureSearchDocReference( + id="1", # Lower ID for testing interleaved sorting + activity_source=1, + doc_key="doc1", + source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + ], + ) + + +def mock_retrieval_response_with_duplicates(): + """Mock response with duplicate doc_keys for testing deduplication""" + return KnowledgeAgentRetrievalResponse( + response=[ + KnowledgeAgentMessage( + role="assistant", + content=[KnowledgeAgentMessageTextContent(text="Test response")], + ) + ], + activity=[ + KnowledgeAgentSearchActivityRecord( + id=1, + target_index="index", + query=KnowledgeAgentSearchActivityRecordQuery(search="query for doc1"), + count=10, + elapsed_ms=50, + ), + KnowledgeAgentSearchActivityRecord( + id=2, + target_index="index", + query=KnowledgeAgentSearchActivityRecordQuery(search="another query for doc1"), + count=10, + elapsed_ms=50, + ), + ], + references=[ + KnowledgeAgentAzureSearchDocReference( + id="1", + activity_source=1, + doc_key="doc1", # Same doc_key + source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + KnowledgeAgentAzureSearchDocReference( + id="2", + activity_source=2, + doc_key="doc1", # Duplicate doc_key + 
source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + KnowledgeAgentAzureSearchDocReference( + id="3", + activity_source=1, + doc_key="doc2", # Different doc_key + source_data={"content": "Content 2", "sourcepage": "page2.pdf"}, + ), + ], + ) + + +def mock_retrieval_response_with_missing_doc_key(): + """Mock response with missing doc_key to test continue condition""" + return KnowledgeAgentRetrievalResponse( + response=[ + KnowledgeAgentMessage( + role="assistant", + content=[KnowledgeAgentMessageTextContent(text="Test response")], + ) + ], + activity=[ + KnowledgeAgentSearchActivityRecord( + id=1, + target_index="index", + query=KnowledgeAgentSearchActivityRecordQuery(search="query"), + count=10, + elapsed_ms=50, + ), + ], + references=[ + KnowledgeAgentAzureSearchDocReference( + id="1", + activity_source=1, + doc_key=None, # Missing doc_key + source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + KnowledgeAgentAzureSearchDocReference( + id="2", + activity_source=1, + doc_key="", # Empty doc_key + source_data={"content": "Content 2", "sourcepage": "page2.pdf"}, + ), + KnowledgeAgentAzureSearchDocReference( + id="3", + activity_source=1, + doc_key="doc3", # Valid doc_key + source_data={"content": "Content 3", "sourcepage": "page3.pdf"}, + ), + ], + ) + + +def mock_retrieval_response_with_top_limit(): + """Mock response with many references to test top limit during document building""" + references = [] + for i in range(15): # More than any reasonable top limit + references.append( + KnowledgeAgentAzureSearchDocReference( + id=str(i), + activity_source=1, + doc_key=f"doc{i}", + source_data={"content": f"Content {i}", "sourcepage": f"page{i}.pdf"}, + ) + ) + + return KnowledgeAgentRetrievalResponse( + response=[ + KnowledgeAgentMessage( + role="assistant", + content=[KnowledgeAgentMessageTextContent(text="Test response")], + ) + ], + activity=[ + KnowledgeAgentSearchActivityRecord( + id=1, + target_index="index", + 
query=KnowledgeAgentSearchActivityRecordQuery(search="query"), + count=10, + elapsed_ms=50, + ), + ], + references=references, + ) + + class MockAudio: def __init__(self, audio_data): self.audio_data = audio_data diff --git a/tests/snapshots/test_app/test_ask_rtr_text_agent/agent_client0/result.json b/tests/snapshots/test_app/test_ask_rtr_text_agent/agent_client0/result.json index 99e7321f54..1fd69cb588 100644 --- a/tests/snapshots/test_app/test_ask_rtr_text_agent/agent_client0/result.json +++ b/tests/snapshots/test_app/test_ask_rtr_text_agent/agent_client0/result.json @@ -28,17 +28,23 @@ { "description": [ { - "captions": [], + "captions": [ + { + "additional_properties": {}, + "highlights": [], + "text": "Caption: A whistleblower policy." + } + ], "category": null, "content": "There is a whistleblower policy.", "groups": null, - "id": "Benefit_Options-2.pdf", + "id": "file-Benefit_Options_pdf-42656E656669745F4F7074696F6E732E706466-page-2", "images": null, "oids": null, - "reranker_score": null, - "score": null, - "search_agent_query": "whistleblower query", - "sourcefile": null, + "reranker_score": 3.4577205181121826, + "score": 0.03279569745063782, + "search_agent_query": null, + "sourcefile": "Benefit_Options.pdf", "sourcepage": "Benefit_Options-2.pdf" } ], diff --git a/tests/snapshots/test_app/test_ask_rtr_text_agent_filter/agent_auth_client0/result.json b/tests/snapshots/test_app/test_ask_rtr_text_agent_filter/agent_auth_client0/result.json index 9ff03d86e1..3193a65e9d 100644 --- a/tests/snapshots/test_app/test_ask_rtr_text_agent_filter/agent_auth_client0/result.json +++ b/tests/snapshots/test_app/test_ask_rtr_text_agent_filter/agent_auth_client0/result.json @@ -28,17 +28,23 @@ { "description": [ { - "captions": [], + "captions": [ + { + "additional_properties": {}, + "highlights": [], + "text": "Caption: A whistleblower policy." 
+ } + ], "category": null, "content": "There is a whistleblower policy.", "groups": null, - "id": "Benefit_Options-2.pdf", + "id": "file-Benefit_Options_pdf-42656E656669745F4F7074696F6E732E706466-page-2", "images": null, "oids": null, - "reranker_score": null, - "score": null, - "search_agent_query": "whistleblower query", - "sourcefile": null, + "reranker_score": 3.4577205181121826, + "score": 0.03279569745063782, + "search_agent_query": null, + "sourcefile": "Benefit_Options.pdf", "sourcepage": "Benefit_Options-2.pdf" } ], diff --git a/tests/snapshots/test_app/test_chat_text_agent/agent_client0/result.json b/tests/snapshots/test_app/test_chat_text_agent/agent_client0/result.json index 968118b582..cbd6e457c1 100644 --- a/tests/snapshots/test_app/test_chat_text_agent/agent_client0/result.json +++ b/tests/snapshots/test_app/test_chat_text_agent/agent_client0/result.json @@ -29,17 +29,23 @@ { "description": [ { - "captions": [], + "captions": [ + { + "additional_properties": {}, + "highlights": [], + "text": "Caption: A whistleblower policy." 
+ } + ], "category": null, "content": "There is a whistleblower policy.", "groups": null, - "id": "Benefit_Options-2.pdf", + "id": "file-Benefit_Options_pdf-42656E656669745F4F7074696F6E732E706466-page-2", "images": null, "oids": null, - "reranker_score": null, - "score": null, - "search_agent_query": "whistleblower query", - "sourcefile": null, + "reranker_score": 3.4577205181121826, + "score": 0.03279569745063782, + "search_agent_query": null, + "sourcefile": "Benefit_Options.pdf", "sourcepage": "Benefit_Options-2.pdf" } ], diff --git a/tests/snapshots/test_app/test_chat_text_filter_agent/agent_auth_client0/result.json b/tests/snapshots/test_app/test_chat_text_filter_agent/agent_auth_client0/result.json index 3c630a15e2..448297c32e 100644 --- a/tests/snapshots/test_app/test_chat_text_filter_agent/agent_auth_client0/result.json +++ b/tests/snapshots/test_app/test_chat_text_filter_agent/agent_auth_client0/result.json @@ -29,17 +29,23 @@ { "description": [ { - "captions": [], + "captions": [ + { + "additional_properties": {}, + "highlights": [], + "text": "Caption: A whistleblower policy." 
+ } + ], "category": null, "content": "There is a whistleblower policy.", "groups": null, - "id": "Benefit_Options-2.pdf", + "id": "file-Benefit_Options_pdf-42656E656669745F4F7074696F6E732E706466-page-2", "images": null, "oids": null, - "reranker_score": null, - "score": null, - "search_agent_query": "whistleblower query", - "sourcefile": null, + "reranker_score": 3.4577205181121826, + "score": 0.03279569745063782, + "search_agent_query": null, + "sourcefile": "Benefit_Options.pdf", "sourcepage": "Benefit_Options-2.pdf" } ], diff --git a/tests/test_agentic_retrieval.py b/tests/test_agentic_retrieval.py new file mode 100644 index 0000000000..656a3fccbc --- /dev/null +++ b/tests/test_agentic_retrieval.py @@ -0,0 +1,308 @@ +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient +from azure.search.documents.agent.models import ( + KnowledgeAgentAzureSearchDocReference, + KnowledgeAgentMessage, + KnowledgeAgentRetrievalResponse, +) +from azure.search.documents.aio import SearchClient + +from .conftest import create_mock_retrieve +from .mocks import ( + MockAsyncSearchResultsIterator, +) + + +@pytest.mark.asyncio +async def test_agentic_retrieval_non_hydrated_default_sort(chat_approach, monkeypatch): + """Test non-hydrated path with default sorting (preserve original order)""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting")) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach.run_agentic_retrieval( + messages=[], + agent_client=agent_client, + search_index_name="test-index", + results_merge_strategy=None, # Default sorting + ) + + assert len(results) == 2 + # Default sorting preserves original order (doc2, doc1) + assert results[0].id == "doc2" + assert results[0].content == "Content 2" + assert results[0].search_agent_query == "second 
query" + + assert results[1].id == "doc1" + assert results[1].content == "Content 1" + assert results[1].search_agent_query == "first query" + + +@pytest.mark.asyncio +async def test_agentic_retrieval_non_hydrated_interleaved_sort(chat_approach, monkeypatch): + """Test non-hydrated path with interleaved sorting""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting")) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach.run_agentic_retrieval( + messages=[], + agent_client=agent_client, + search_index_name="test-index", + results_merge_strategy="interleaved", + ) + + assert len(results) == 2 + # Interleaved sorting orders by reference ID (1, 2) + assert results[0].id == "doc1" # ref.id = "1" + assert results[0].content == "Content 1" + assert results[0].search_agent_query == "first query" + + assert results[1].id == "doc2" # ref.id = "2" + assert results[1].content == "Content 2" + assert results[1].search_agent_query == "second query" + + +@pytest.mark.asyncio +async def test_agentic_retrieval_hydrated_with_sorting(chat_approach_with_hydration, monkeypatch): + """Test hydrated path with sorting""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting")) + + async def mock_search(self, *args, **kwargs): + # For hydration, we expect a filter like "search.in(id, 'doc1,doc2', ',')" + return MockAsyncSearchResultsIterator("hydrated_multi", None) + + monkeypatch.setattr(SearchClient, "search", mock_search) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach_with_hydration.run_agentic_retrieval( + messages=[], + agent_client=agent_client, + search_index_name="test-index", + results_merge_strategy="interleaved", + ) + + assert len(results) == 2 + # Should have hydrated content, not source_data 
content + assert results[0].content == "Hydrated content 1" + assert results[1].content == "Hydrated content 2" + # Should still have agent queries injected + assert results[0].search_agent_query == "first query" + assert results[1].search_agent_query == "second query" + + +@pytest.mark.asyncio +async def test_hydrate_agent_references_deduplication(chat_approach_with_hydration, monkeypatch): + """Test that hydrate_agent_references deduplicates doc_keys""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("duplicates")) + + async def mock_search(self, *args, **kwargs): + # For deduplication test, we expect doc1 and doc2 to be in the filter + return MockAsyncSearchResultsIterator("hydrated_multi", None) + + monkeypatch.setattr(SearchClient, "search", mock_search) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach_with_hydration.run_agentic_retrieval( + messages=[], agent_client=agent_client, search_index_name="test-index" + ) + + # Should only get 2 unique documents despite 3 references (doc1 appears twice) + assert len(results) == 2 + doc_ids = [doc.id for doc in results] + assert "doc1" in doc_ids + assert "doc2" in doc_ids + + +@pytest.mark.asyncio +async def test_agentic_retrieval_no_references(chat_approach, monkeypatch): + """Test behavior when agent returns no references""" + + async def mock_retrieval(*args, **kwargs): + return KnowledgeAgentRetrievalResponse( + response=[KnowledgeAgentMessage(role="assistant", content=[])], + activity=[], + references=[], + ) + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach.run_agentic_retrieval( + messages=[], agent_client=agent_client, search_index_name="test-index" + ) + + assert len(results) == 0 + + 
+@pytest.mark.asyncio
+async def test_activity_mapping_injection(chat_approach, monkeypatch):
+    """Test that search_agent_query is properly injected from activity mapping"""
+
+    monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting"))
+
+    agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
+
+    _, results = await chat_approach.run_agentic_retrieval(
+        messages=[], agent_client=agent_client, search_index_name="test-index"
+    )
+
+    # Verify that search_agent_query is correctly mapped from activity
+    assert len(results) == 2
+
+    # Find each document and verify its query
+    doc1 = next(doc for doc in results if doc.id == "doc1")
+    doc2 = next(doc for doc in results if doc.id == "doc2")
+
+    assert doc1.search_agent_query == "first query"  # From activity_source=1
+    assert doc2.search_agent_query == "second query"  # From activity_source=2
+
+
+@pytest.mark.asyncio
+async def test_hydrate_agent_references_missing_doc_keys(chat_approach_with_hydration, monkeypatch):
+    """Test that hydrate_agent_references handles missing/empty doc_keys correctly"""
+
+    monkeypatch.setattr(
+        KnowledgeAgentRetrievalClient,
+        "retrieve",
+        create_mock_retrieve("missing_doc_key"),
+    )
+
+    async def mock_search(self, *args, **kwargs):
+        return MockAsyncSearchResultsIterator("hydrated_single", None)
+
+    monkeypatch.setattr(SearchClient, "search", mock_search)
+
+    agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
+
+    _, results = await chat_approach_with_hydration.run_agentic_retrieval(
+        messages=[], agent_client=agent_client, search_index_name="test-index"
+    )
+
+    # Should only get one hydrated document since doc_key was missing/empty for the other references
+    assert len(results) == 1
+    assert results[0].id == "doc1"  # From mock search result
+    assert results[0].content == "Hydrated content 1"
+
+
+@pytest.mark.asyncio
+async def 
test_hydrate_agent_references_empty_doc_keys(chat_approach_with_hydration, monkeypatch): + """Test that hydrate_agent_references handles case with no valid doc_keys""" + + async def mock_retrieval_no_valid_keys(*args, **kwargs): + return KnowledgeAgentRetrievalResponse( + response=[KnowledgeAgentMessage(role="assistant", content=[])], + activity=[], + references=[ + KnowledgeAgentAzureSearchDocReference( + id="1", + activity_source=1, + doc_key=None, # No valid doc_key + source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + ], + ) + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_no_valid_keys) + # No need to mock search since it should never be called + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach_with_hydration.run_agentic_retrieval( + messages=[], agent_client=agent_client, search_index_name="test-index" + ) + + # Should get empty results since no valid doc_keys + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_hydrate_agent_references_search_returns_empty(chat_approach_with_hydration, monkeypatch): + """Test that hydrate_agent_references handles case where search returns no results""" + + async def mock_retrieval_valid_keys(*args, **kwargs): + return KnowledgeAgentRetrievalResponse( + response=[KnowledgeAgentMessage(role="assistant", content=[])], + activity=[], + references=[ + KnowledgeAgentAzureSearchDocReference( + id="1", + activity_source=1, + doc_key="nonexistent_doc", # Valid doc_key but document doesn't exist + source_data={"content": "Content 1", "sourcepage": "page1.pdf"}, + ), + ], + ) + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_valid_keys) + + async def mock_search(self, *args, **kwargs): + return MockAsyncSearchResultsIterator("hydrated_empty", None) + + monkeypatch.setattr(SearchClient, "search", mock_search) + + agent_client = 
KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach_with_hydration.run_agentic_retrieval( + messages=[], agent_client=agent_client, search_index_name="test-index" + ) + + # When hydration is enabled but returns empty results, we should get empty list + # rather than falling back to source_data (this is the expected behavior) + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_agentic_retrieval_with_top_limit_during_building(chat_approach, monkeypatch): + """Test that document building respects top limit and breaks early (non-hydrated path)""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("top_limit")) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach.run_agentic_retrieval( + messages=[], + agent_client=agent_client, + search_index_name="test-index", + top=5, # Limit to 5 documents + ) + + # Should get exactly 5 documents due to top limit during building + assert len(results) == 5 + for i, result in enumerate(results): + assert result.id == f"doc{i}" + assert result.content == f"Content {i}" + + +@pytest.mark.asyncio +async def test_hydrate_agent_references_with_top_limit_during_collection(chat_approach_with_hydration, monkeypatch): + """Test that hydration respects top limit when collecting doc_keys""" + + monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("top_limit")) + + async def mock_search(self, *args, **kwargs): + return MockAsyncSearchResultsIterator("hydrated_multi", None) + + monkeypatch.setattr(SearchClient, "search", mock_search) + + agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) + + _, results = await chat_approach_with_hydration.run_agentic_retrieval( + messages=[], + agent_client=agent_client, + search_index_name="test-index", 
+ top=2, # Limit to 2 documents + ) + + # Should get exactly 2 documents due to top limit during doc_keys collection + assert len(results) == 2 + assert results[0].content == "Hydrated content 1" + assert results[1].content == "Hydrated content 2" diff --git a/tests/test_chatapproach.py b/tests/test_chatapproach.py index 47f98083e3..aa6145f273 100644 --- a/tests/test_chatapproach.py +++ b/tests/test_chatapproach.py @@ -2,7 +2,6 @@ import pytest from azure.core.credentials import AzureKeyCredential -from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient from azure.search.documents.aio import SearchClient from azure.search.documents.models import VectorizedQuery from openai.types.chat import ChatCompletion @@ -153,30 +152,8 @@ def test_extract_followup_questions_no_pre_content(chat_approach): ], ) async def test_search_results_filtering_by_scores( - monkeypatch, minimum_search_score, minimum_reranker_score, expected_result_count + chat_approach, monkeypatch, minimum_search_score, minimum_reranker_score, expected_result_count ): - - chat_approach = ChatReadRetrieveReadApproach( - search_client=SearchClient(endpoint="", index_name="", credential=AzureKeyCredential("")), - search_index_name=None, - agent_model=None, - agent_deployment=None, - agent_client=None, - auth_helper=None, - openai_client=None, - chatgpt_model="gpt-4.1-mini", - chatgpt_deployment="chat", - embedding_deployment="embeddings", - embedding_model=MOCK_EMBEDDING_MODEL_NAME, - embedding_dimensions=MOCK_EMBEDDING_DIMENSIONS, - embedding_field="embedding3", - sourcepage_field="", - content_field="", - query_language="en-us", - query_speller="lexicon", - prompt_manager=PromptyManager(), - ) - monkeypatch.setattr(SearchClient, "search", mock_search) filtered_results = await chat_approach.search( @@ -198,27 +175,7 @@ async def test_search_results_filtering_by_scores( @pytest.mark.asyncio -async def test_search_results_query_rewriting(monkeypatch): - chat_approach = 
ChatReadRetrieveReadApproach( - search_client=SearchClient(endpoint="", index_name="", credential=AzureKeyCredential("")), - search_index_name=None, - agent_model=None, - agent_deployment=None, - agent_client=None, - auth_helper=None, - openai_client=None, - chatgpt_model="gpt-35-turbo", - chatgpt_deployment="chat", - embedding_deployment="embeddings", - embedding_model=MOCK_EMBEDDING_MODEL_NAME, - embedding_dimensions=MOCK_EMBEDDING_DIMENSIONS, - embedding_field="embedding3", - sourcepage_field="", - content_field="", - query_language="en-us", - query_speller="lexicon", - prompt_manager=PromptyManager(), - ) +async def test_search_results_query_rewriting(chat_approach, monkeypatch): query_rewrites = None @@ -244,42 +201,6 @@ async def validate_qr_and_mock_search(*args, **kwargs): assert query_rewrites == "generative" -@pytest.mark.asyncio -async def test_agent_retrieval_results(monkeypatch): - chat_approach = ChatReadRetrieveReadApproach( - search_client=None, - search_index_name=None, - agent_model=None, - agent_deployment=None, - agent_client=None, - auth_helper=None, - openai_client=None, - chatgpt_model="gpt-35-turbo", - chatgpt_deployment="chat", - embedding_deployment="embeddings", - embedding_model=MOCK_EMBEDDING_MODEL_NAME, - embedding_dimensions=MOCK_EMBEDDING_DIMENSIONS, - embedding_field="embedding3", - sourcepage_field="", - content_field="", - query_language="en-us", - query_speller="lexicon", - prompt_manager=PromptyManager(), - ) - - agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential("")) - - monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval) - - _, results = await chat_approach.run_agentic_retrieval(messages=[], agent_client=agent_client, search_index_name="") - - assert len(results) == 1 - assert results[0].id == "Benefit_Options-2.pdf" - assert results[0].content == "There is a whistleblower policy." 
- assert results[0].sourcepage == "Benefit_Options-2.pdf" - assert results[0].search_agent_query == "whistleblower query" - - @pytest.mark.asyncio async def test_compute_multimodal_embedding(monkeypatch, chat_approach): # Create a mock for the ImageEmbeddings.create_embedding_for_text method