Skip to content

Commit d342d4c

Browse files
author
Julien Almarcha
committed
Add test with uploaded document
1 parent d9a6cfb commit d342d4c

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

app/endpoints/multiagents/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def multiagents(
137137
user_id=request.app.state.user.id,
138138
)
139139
initial_docs = [doc.chunk.content for doc in searches]
140-
initial_refs = [doc.chunk.metadata.document_name for doc in searches]
140+
initial_refs = [doc.chunk.metadata.get("document_name") for doc in searches]
141141

142142
async def go_multiagents(body, initial_docs, initial_refs, n_retry, max_retry=5, window=5):
143143
docs_tmp = initial_docs[n_retry * window : (n_retry + 1) * window]
@@ -155,7 +155,7 @@ async def go_multiagents(body, initial_docs, initial_refs, n_retry, max_retry=5,
155155
return await go_multiagents(body, initial_docs, initial_refs, n_retry=n_retry + 1, max_retry=5, window=5)
156156
elif choice in [1, 2]:
157157
pass
158-
elif choice == 4 or n_retry >= max_retry: # else ?
158+
elif choice == 4 or n_retry >= max_retry:
159159
searches = await context.documents.search(
160160
session=session,
161161
collection_ids=body.collections,
@@ -167,7 +167,9 @@ async def go_multiagents(body, initial_docs, initial_refs, n_retry, max_retry=5,
167167
user_id=request.app.state.user.id,
168168
)
169169
docs_tmp = [doc.chunk.content for doc in searches]
170-
refs_tmp = [doc.chunk.metadata.document_name for doc in searches]
170+
refs_tmp = [doc.chunk.metadata.get("document_name") for doc in searches]
171+
else:
172+
raise ValueError(f"Unknown choice: {choice}")
171173
prompts = get_prompt_teller_multi(body.prompt, docs_tmp, choice)
172174
answers = await ask_in_parallel(model, prompts, user, body.max_tokens_intermediate)
173175
prompt = PROMPT_CONCAT.format(prompt=body.prompt, answers=answers)

app/tests/test_multiagents.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
import pytest
23
from uuid import uuid4
34
from fastapi.testclient import TestClient
45
from app.schemas.collections import CollectionVisibility
5-
from app.utils.variables import ENDPOINT__MULTIAGENTS, ENDPOINT__COLLECTIONS
6+
from app.utils.variables import ENDPOINT__FILES, ENDPOINT__MULTIAGENTS, ENDPOINT__COLLECTIONS
67

78

89
@pytest.mark.usefixtures("client")
@@ -26,8 +27,6 @@ def test_multiagents_basic(self, client: TestClient):
2627
"k": 3,
2728
"rff_k": 1,
2829
"score_threshold": 0.5,
29-
"writers_model": "writers_model_example",
30-
"supervisor_model": "supervisor_model_example",
3130
"max_tokens": 50,
3231
"max_tokens_intermediate": 20,
3332
"model": "albert-small",
@@ -45,3 +44,49 @@ def test_multiagents_basic(self, client: TestClient):
4544
assert "n_retry" in data
4645
assert "sources_refs" in data
4746
assert "sources_content" in data
47+
48+
def test_multiagent_rag(self, client: TestClient):
49+
"""
50+
Test the /multiagents endpoint after uploading a PDF file,
51+
then calling the endpoint similarly to test_multiagents_basic
52+
but using a different prompt.
53+
"""
54+
# Create a private collection for the test
55+
collection_name = f"test_collection_{str(uuid4())}"
56+
params = {"name": collection_name, "visibility": CollectionVisibility.PRIVATE}
57+
response = client.post_without_permissions(url=f"/v1{ENDPOINT__COLLECTIONS}", json=params)
58+
assert response.status_code == 201, response.text
59+
collection_id = response.json()["id"]
60+
61+
# Upload pdf.pdf into the new collection
62+
file_path = "app/tests/assets/pdf.pdf"
63+
with open(file_path, "rb") as file:
64+
files = {"file": (os.path.basename(file_path), file, "application/pdf")}
65+
data = {"request": '{"collection": "%s"}' % collection_id}
66+
upload_response = client.post_without_permissions(url=f"/v1{ENDPOINT__FILES}", data=data, files=files)
67+
file.close()
68+
assert upload_response.status_code == 201, upload_response.text
69+
70+
# Now test the /multiagents endpoint with a different prompt
71+
payload = {
72+
"prompt": "Quel est le montant maximum des actes dont la signature de la première ministre peut être déléguée ?",
73+
"collections": [collection_id],
74+
"method": "semantic",
75+
"k": 3,
76+
"rff_k": 1,
77+
"score_threshold": 0.5,
78+
"max_tokens": 50,
79+
"max_tokens_intermediate": 20,
80+
"model": "albert-small",
81+
}
82+
response = client.post_with_permissions(f"/v1{ENDPOINT__MULTIAGENTS}", json=payload)
83+
assert response.status_code == 200, response.text
84+
data = response.json()
85+
86+
# Check response schema
87+
assert "answer" in data
88+
assert "choice" in data
89+
assert "choice_desc" in data
90+
assert "n_retry" in data
91+
assert "sources_refs" in data
92+
assert "sources_content" in data

0 commit comments

Comments
 (0)