Skip to content

Commit eab9a3e

Browse files
TaylorTaylor
authored andcommitted
Refactor mock retrieval functions for improved flexibility
- 🎨 Introduce `create_mock_retrieve` to parameterise mock retrieval responses. - 🔄 Remove redundant mock search functions to streamline code. - 🧪 Update tests to use the new mock retrieval function for various scenarios. - 🧹 Clean up unused mock functions to enhance maintainability.
1 parent 381a141 commit eab9a3e

File tree

3 files changed

+91
-164
lines changed

3 files changed

+91
-164
lines changed

tests/conftest.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
MockResponse,
4848
MockTransport,
4949
mock_retrieval_response,
50+
mock_retrieval_response_with_duplicates,
51+
mock_retrieval_response_with_missing_doc_key,
52+
mock_retrieval_response_with_sorting,
53+
mock_retrieval_response_with_top_limit,
5054
mock_speak_text_cancelled,
5155
mock_speak_text_failed,
5256
mock_speak_text_success,
@@ -68,13 +72,37 @@ async def mock_search(self, *args, **kwargs):
6872
return MockAsyncSearchResultsIterator(kwargs.get("search_text"), kwargs.get("vector_queries"))
6973

7074

71-
async def mock_retrieve(self, *args, **kwargs):
72-
retrieval_request = kwargs.get("retrieval_request")
73-
assert retrieval_request is not None
74-
assert retrieval_request.target_index_params is not None
75-
assert len(retrieval_request.target_index_params) == 1
76-
self.filter = retrieval_request.target_index_params[0].filter_add_on
77-
return mock_retrieval_response()
75+
def create_mock_retrieve(response_type="default"):
76+
"""Create a mock_retrieve function that returns different response types.
77+
78+
Args:
79+
response_type: Type of response to return. Options:
80+
- "default": mock_retrieval_response()
81+
- "sorting": mock_retrieval_response_with_sorting()
82+
- "duplicates": mock_retrieval_response_with_duplicates()
83+
- "missing_doc_key": mock_retrieval_response_with_missing_doc_key()
84+
- "top_limit": mock_retrieval_response_with_top_limit()
85+
"""
86+
87+
async def mock_retrieve_parameterized(self, *args, **kwargs):
88+
retrieval_request = kwargs.get("retrieval_request")
89+
assert retrieval_request is not None
90+
assert retrieval_request.target_index_params is not None
91+
assert len(retrieval_request.target_index_params) == 1
92+
self.filter = retrieval_request.target_index_params[0].filter_add_on
93+
94+
if response_type == "sorting":
95+
return mock_retrieval_response_with_sorting()
96+
elif response_type == "duplicates":
97+
return mock_retrieval_response_with_duplicates()
98+
elif response_type == "missing_doc_key":
99+
return mock_retrieval_response_with_missing_doc_key()
100+
elif response_type == "top_limit":
101+
return mock_retrieval_response_with_top_limit()
102+
else: # default
103+
return mock_retrieval_response()
104+
105+
return mock_retrieve_parameterized
78106

79107

80108
@pytest.fixture
@@ -281,7 +309,7 @@ async def mock_get_index(*args, **kwargs):
281309

282310
@pytest.fixture
283311
def mock_acs_agent(monkeypatch):
284-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieve)
312+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve())
285313

286314
async def mock_get_agent(*args, **kwargs):
287315
return MockAgent

tests/mocks.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -545,22 +545,6 @@ def mock_retrieval_response_with_duplicates():
545545
)
546546

547547

548-
async def mock_search_for_hydration(*args, **kwargs):
549-
"""Mock search that returns documents matching the filter"""
550-
filter_param = kwargs.get("filter", "")
551-
552-
# Create documents based on filter - use search_text to distinguish different calls
553-
search_text = ""
554-
if "doc1" in filter_param and "doc2" in filter_param:
555-
search_text = "hydrated_multi"
556-
elif "doc1" in filter_param:
557-
search_text = "hydrated_single"
558-
else:
559-
search_text = "hydrated_empty"
560-
561-
return MockAsyncSearchResultsIterator(search_text, None)
562-
563-
564548
def mock_retrieval_response_with_missing_doc_key():
565549
"""Mock response with missing doc_key to test continue condition"""
566550
return KnowledgeAgentRetrievalResponse(

tests/test_agentic_retrieval.py

Lines changed: 55 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,21 @@
44
from azure.search.documents.agent.models import (
55
KnowledgeAgentAzureSearchDocReference,
66
KnowledgeAgentMessage,
7-
KnowledgeAgentMessageTextContent,
87
KnowledgeAgentRetrievalResponse,
9-
KnowledgeAgentSearchActivityRecord,
10-
KnowledgeAgentSearchActivityRecordQuery,
118
)
129
from azure.search.documents.aio import SearchClient
1310

11+
from .conftest import create_mock_retrieve
1412
from .mocks import (
1513
MockAsyncSearchResultsIterator,
16-
mock_retrieval_response,
17-
mock_retrieval_response_with_duplicates,
18-
mock_retrieval_response_with_sorting,
1914
)
2015

2116

22-
23-
24-
async def mock_search(*args, **kwargs):
25-
return MockAsyncSearchResultsIterator(kwargs.get("search_text"), kwargs.get("vector_queries"))
26-
27-
28-
async def mock_search_for_hydration(*args, **kwargs):
29-
filter_param = kwargs.get("filter", "")
30-
31-
search_text = ""
32-
if "doc1" in filter_param and "doc2" in filter_param:
33-
search_text = "hydrated_multi"
34-
elif "doc1" in filter_param:
35-
search_text = "hydrated_single"
36-
else:
37-
search_text = "hydrated_empty"
38-
39-
kwargs["search_text"] = search_text
40-
41-
return mock_search(*args, **kwargs)
42-
43-
4417
@pytest.mark.asyncio
4518
async def test_agentic_retrieval_non_hydrated_default_sort(chat_approach, monkeypatch):
4619
"""Test non-hydrated path with default sorting (preserve original order)"""
4720

48-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_response_with_sorting)
21+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting"))
4922

5023
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
5124

@@ -71,12 +44,15 @@ async def test_agentic_retrieval_non_hydrated_default_sort(chat_approach, monkey
7144
async def test_agentic_retrieval_non_hydrated_interleaved_sort(chat_approach, monkeypatch):
7245
"""Test non-hydrated path with interleaved sorting"""
7346

74-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_response_with_sorting)
47+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting"))
7548

7649
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
7750

7851
_, results = await chat_approach.run_agentic_retrieval(
79-
messages=[], agent_client=agent_client, search_index_name="test-index", results_merge_strategy="interleaved"
52+
messages=[],
53+
agent_client=agent_client,
54+
search_index_name="test-index",
55+
results_merge_strategy="interleaved",
8056
)
8157

8258
assert len(results) == 2
@@ -94,13 +70,21 @@ async def test_agentic_retrieval_non_hydrated_interleaved_sort(chat_approach, mo
9470
async def test_agentic_retrieval_hydrated_with_sorting(chat_approach_with_hydration, monkeypatch):
9571
"""Test hydrated path with sorting"""
9672

97-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_response_with_sorting)
98-
monkeypatch.setattr(SearchClient, "search", mock_search_for_hydration)
73+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting"))
74+
75+
async def mock_search(self, *args, **kwargs):
76+
# For hydration, we expect a filter like "search.in(id, 'doc1,doc2', ',')"
77+
return MockAsyncSearchResultsIterator("hydrated_multi", None)
78+
79+
monkeypatch.setattr(SearchClient, "search", mock_search)
9980

10081
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
10182

10283
_, results = await chat_approach_with_hydration.run_agentic_retrieval(
103-
messages=[], agent_client=agent_client, search_index_name="test-index", results_merge_strategy="interleaved"
84+
messages=[],
85+
agent_client=agent_client,
86+
search_index_name="test-index",
87+
results_merge_strategy="interleaved",
10488
)
10589

10690
assert len(results) == 2
@@ -116,8 +100,13 @@ async def test_agentic_retrieval_hydrated_with_sorting(chat_approach_with_hydrat
116100
async def test_hydrate_agent_references_deduplication(chat_approach_with_hydration, monkeypatch):
117101
"""Test that hydrate_agent_references deduplicates doc_keys"""
118102

119-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_response_with_duplicates)
120-
monkeypatch.setattr(SearchClient, "search", mock_search_for_hydration)
103+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("duplicates"))
104+
105+
async def mock_search(self, *args, **kwargs):
106+
# For deduplication test, we expect doc1 and doc2 to be in the filter
107+
return MockAsyncSearchResultsIterator("hydrated_multi", None)
108+
109+
monkeypatch.setattr(SearchClient, "search", mock_search)
121110

122111
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
123112

@@ -138,7 +127,9 @@ async def test_agentic_retrieval_no_references(chat_approach, monkeypatch):
138127

139128
async def mock_retrieval(*args, **kwargs):
140129
return KnowledgeAgentRetrievalResponse(
141-
response=[KnowledgeAgentMessage(role="assistant", content=[])], activity=[], references=[]
130+
response=[KnowledgeAgentMessage(role="assistant", content=[])],
131+
activity=[],
132+
references=[],
142133
)
143134

144135
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval)
@@ -156,10 +147,7 @@ async def mock_retrieval(*args, **kwargs):
156147
async def test_activity_mapping_injection(chat_approach, monkeypatch):
157148
"""Test that search_agent_query is properly injected from activity mapping"""
158149

159-
async def mock_retrieval(*args, **kwargs):
160-
return mock_retrieval_response_with_sorting()
161-
162-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval)
150+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("sorting"))
163151

164152
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
165153

@@ -178,93 +166,20 @@ async def mock_retrieval(*args, **kwargs):
178166
assert doc2.search_agent_query == "second query" # From activity_source=2
179167

180168

181-
def mock_retrieval_response_with_missing_doc_key():
182-
"""Mock response with missing doc_key to test continue condition"""
183-
return KnowledgeAgentRetrievalResponse(
184-
response=[
185-
KnowledgeAgentMessage(
186-
role="assistant",
187-
content=[KnowledgeAgentMessageTextContent(text="Test response")],
188-
)
189-
],
190-
activity=[
191-
KnowledgeAgentSearchActivityRecord(
192-
id=1,
193-
target_index="index",
194-
query=KnowledgeAgentSearchActivityRecordQuery(search="query"),
195-
count=10,
196-
elapsed_ms=50,
197-
),
198-
],
199-
references=[
200-
KnowledgeAgentAzureSearchDocReference(
201-
id="1",
202-
activity_source=1,
203-
doc_key=None, # Missing doc_key
204-
source_data={"content": "Content 1", "sourcepage": "page1.pdf"},
205-
),
206-
KnowledgeAgentAzureSearchDocReference(
207-
id="2",
208-
activity_source=1,
209-
doc_key="", # Empty doc_key
210-
source_data={"content": "Content 2", "sourcepage": "page2.pdf"},
211-
),
212-
KnowledgeAgentAzureSearchDocReference(
213-
id="3",
214-
activity_source=1,
215-
doc_key="doc3", # Valid doc_key
216-
source_data={"content": "Content 3", "sourcepage": "page3.pdf"},
217-
),
218-
],
219-
)
220-
221-
222-
def mock_retrieval_response_with_top_limit():
223-
"""Mock response with many references to test top limit during document building"""
224-
references = []
225-
for i in range(15): # More than any reasonable top limit
226-
references.append(
227-
KnowledgeAgentAzureSearchDocReference(
228-
id=str(i),
229-
activity_source=1,
230-
doc_key=f"doc{i}",
231-
source_data={"content": f"Content {i}", "sourcepage": f"page{i}.pdf"},
232-
)
233-
)
234-
235-
return KnowledgeAgentRetrievalResponse(
236-
response=[
237-
KnowledgeAgentMessage(
238-
role="assistant",
239-
content=[KnowledgeAgentMessageTextContent(text="Test response")],
240-
)
241-
],
242-
activity=[
243-
KnowledgeAgentSearchActivityRecord(
244-
id=1,
245-
target_index="index",
246-
query=KnowledgeAgentSearchActivityRecordQuery(search="query"),
247-
count=10,
248-
elapsed_ms=50,
249-
),
250-
],
251-
references=references,
252-
)
253-
254-
255169
@pytest.mark.asyncio
256170
async def test_hydrate_agent_references_missing_doc_keys(chat_approach_with_hydration, monkeypatch):
257171
"""Test that hydrate_agent_references handles missing/empty doc_keys correctly"""
258172

259-
async def mock_retrieval(*args, **kwargs):
260-
return mock_retrieval_response_with_missing_doc_key()
173+
monkeypatch.setattr(
174+
KnowledgeAgentRetrievalClient,
175+
"retrieve",
176+
create_mock_retrieve("missing_doc_key"),
177+
)
261178

262-
# Mock search to return single document for doc3
263-
async def mock_search_single(*args, **kwargs):
179+
async def mock_search(self, *args, **kwargs):
264180
return MockAsyncSearchResultsIterator("hydrated_single", None)
265181

266-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval)
267-
monkeypatch.setattr(SearchClient, "search", mock_search_single)
182+
monkeypatch.setattr(SearchClient, "search", mock_search)
268183

269184
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
270185

@@ -327,12 +242,12 @@ async def mock_retrieval_valid_keys(*args, **kwargs):
327242
],
328243
)
329244

330-
# Mock search to return empty results (no documents found)
331-
async def mock_search_returns_empty(*args, **kwargs):
245+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_valid_keys)
246+
247+
async def mock_search(self, *args, **kwargs):
332248
return MockAsyncSearchResultsIterator("hydrated_empty", None)
333249

334-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval_valid_keys)
335-
monkeypatch.setattr(SearchClient, "search", mock_search_returns_empty)
250+
monkeypatch.setattr(SearchClient, "search", mock_search)
336251

337252
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
338253

@@ -346,18 +261,18 @@ async def mock_search_returns_empty(*args, **kwargs):
346261

347262

348263
@pytest.mark.asyncio
349-
async def test_agentic_retrieval_with_top_limit_during_building(chat_approach_with_hydration, monkeypatch):
350-
"""Test that document building respects top limit and breaks early"""
264+
async def test_agentic_retrieval_with_top_limit_during_building(chat_approach, monkeypatch):
265+
"""Test that document building respects top limit and breaks early (non-hydrated path)"""
351266

352-
async def mock_retrieval(*args, **kwargs):
353-
return mock_retrieval_response_with_top_limit()
354-
355-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval)
267+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("top_limit"))
356268

357269
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
358270

359-
_, results = await chat_approach_with_hydration.run_agentic_retrieval(
360-
messages=[], agent_client=agent_client, search_index_name="test-index", top=5 # Limit to 5 documents
271+
_, results = await chat_approach.run_agentic_retrieval(
272+
messages=[],
273+
agent_client=agent_client,
274+
search_index_name="test-index",
275+
top=5, # Limit to 5 documents
361276
)
362277

363278
# Should get exactly 5 documents due to top limit during building
@@ -371,20 +286,20 @@ async def mock_retrieval(*args, **kwargs):
371286
async def test_hydrate_agent_references_with_top_limit_during_collection(chat_approach_with_hydration, monkeypatch):
372287
"""Test that hydration respects top limit when collecting doc_keys"""
373288

374-
async def mock_retrieval(*args, **kwargs):
375-
return mock_retrieval_response_with_top_limit()
289+
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", create_mock_retrieve("top_limit"))
376290

377-
# Mock search to return multi results (more than our top limit)
378-
async def mock_search_multi(*args, **kwargs):
291+
async def mock_search(self, *args, **kwargs):
379292
return MockAsyncSearchResultsIterator("hydrated_multi", None)
380293

381-
monkeypatch.setattr(KnowledgeAgentRetrievalClient, "retrieve", mock_retrieval)
382-
monkeypatch.setattr(SearchClient, "search", mock_search_multi)
294+
monkeypatch.setattr(SearchClient, "search", mock_search)
383295

384296
agent_client = KnowledgeAgentRetrievalClient(endpoint="", agent_name="", credential=AzureKeyCredential(""))
385297

386298
_, results = await chat_approach_with_hydration.run_agentic_retrieval(
387-
messages=[], agent_client=agent_client, search_index_name="test-index", top=2 # Limit to 2 documents
299+
messages=[],
300+
agent_client=agent_client,
301+
search_index_name="test-index",
302+
top=2, # Limit to 2 documents
388303
)
389304

390305
# Should get exactly 2 documents due to top limit during doc_keys collection

0 commit comments

Comments
 (0)