Skip to content

Commit 9006e1e

Browse files
author
Julien Almarcha
committed
add internet search test
1 parent c2d828e commit 9006e1e

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

app/helpers/_documentmanager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
import logging
23
import time
34
import traceback
@@ -351,7 +352,7 @@ async def search(
351352
)
352353
if method == SearchMethod.MULTIAGENT:
353354
searches = await multiagents.search(
354-
self.search,
355+
partial(self.search, user_id=user_id),
355356
searches,
356357
prompt,
357358
method,

app/tests/test_multiagents.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from app.schemas.collections import CollectionVisibility
66
from app.schemas.search import SearchMethod
77
from app.utils.variables import ENDPOINT__FILES, ENDPOINT__SEARCH, ENDPOINT__COLLECTIONS
8+
from unittest.mock import patch
89

910

1011
@pytest.mark.usefixtures("client")
@@ -36,15 +37,13 @@ def test_multiagents_basic(self, client: TestClient):
3637
# Post to /search with the new collection
3738
response = client.post_without_permissions(f"/v1{ENDPOINT__SEARCH}", json=payload)
3839
assert response.status_code == 200, response.text
39-
data = response.json()
40+
data = response.json()["data"][0]
4041

4142
# Check response schema
42-
assert "answer" in data
43-
assert "choice" in data
44-
assert "choice_desc" in data
45-
assert "n_retry" in data
46-
assert "sources_refs" in data
47-
assert "sources_content" in data
43+
assert data["method"] == SearchMethod.MULTIAGENT
44+
assert data["score"] == 1.0
45+
for key in ["choice", "choice_desc", "sources_refs", "sources_content"]:
46+
assert key in data["chunk"]["metadata"]
4847

4948
def test_multiagent_rag(self, client: TestClient):
5049
"""
@@ -72,7 +71,7 @@ def test_multiagent_rag(self, client: TestClient):
7271
payload = {
7372
"prompt": "Quel est le montant maximum des actes dont la signature de la première ministre peut être déléguée ?",
7473
"collections": [collection_id],
75-
"method": "semantic",
74+
"method": SearchMethod.MULTIAGENT,
7675
"k": 3,
7776
"rff_k": 1,
7877
"score_threshold": 0.5,
@@ -82,12 +81,45 @@ def test_multiagent_rag(self, client: TestClient):
8281
}
8382
response = client.post_without_permissions(f"/v1{ENDPOINT__SEARCH}", json=payload)
8483
assert response.status_code == 200, response.text
85-
data = response.json()
86-
84+
data = response.json()["data"][0]
8785
# Check response schema
88-
assert "answer" in data
89-
assert "choice" in data
90-
assert "choice_desc" in data
91-
assert "n_retry" in data
92-
assert "sources_refs" in data
93-
assert "sources_content" in data
86+
assert data["method"] == SearchMethod.MULTIAGENT
87+
assert data["score"] == 1.0
88+
for key in ["choice", "choice_desc", "sources_refs", "sources_content"]:
89+
assert key in data["chunk"]["metadata"]
90+
91+
def test_multiagent_internet_search(self, client: TestClient):
92+
"""
93+
Test the /multiagents endpoint with internet search enabled by patching get_rank to return [4].
94+
"""
95+
# Patch get_rank to always return [4]
96+
with patch("app.utils.multiagents.get_rank", return_value=[4]):
97+
# Create a private collection for the test
98+
collection_name = f"test_collection_{str(uuid4())}"
99+
params = {"name": collection_name, "visibility": CollectionVisibility.PRIVATE}
100+
response = client.post_without_permissions(url=f"/v1{ENDPOINT__COLLECTIONS}", json=params)
101+
assert response.status_code == 201, response.text
102+
collection_id = response.json()["id"]
103+
104+
# Test the /multiagents endpoint with a prompt
105+
payload = {
106+
"prompt": "Recherchez des informations sur la réforme des retraites en France.",
107+
"collections": [collection_id],
108+
"method": SearchMethod.MULTIAGENT,
109+
"k": 3,
110+
"rff_k": 1,
111+
"score_threshold": 0.5,
112+
"max_tokens": 50,
113+
"max_tokens_intermediate": 20,
114+
"model": "albert-small",
115+
}
116+
response = client.post_without_permissions(f"/v1{ENDPOINT__SEARCH}", json=payload)
117+
assert response.status_code == 200, response.text
118+
data = response.json()["data"][0]
119+
120+
# Check response schema
121+
assert data["method"] == SearchMethod.MULTIAGENT
122+
assert data["score"] == 1.0
123+
assert data["chunk"]["metadata"]["choice"] == 4
124+
for key in ["choice", "choice_desc", "sources_refs", "sources_content"]:
125+
assert key in data["chunk"]["metadata"]

0 commit comments

Comments
 (0)