Skip to content

Commit 7c9d34c

Browse files
committed
add rerank top_n unit test
here is the result : ./tests.sh unit/test_rerank.py -v -x =================================================================== test session starts =================================================================== platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3 cachedir: .pytest_cache rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests configfile: pytest.ini plugins: anyio-4.11.0 collected 16 items unit/test_rerank.py::test_rerank PASSED [ 6%] unit/test_rerank.py::test_rerank_tei_format PASSED [ 12%] unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED [ 18%] unit/test_rerank.py::test_invalid_rerank_req[None] PASSED [ 25%] unit/test_rerank.py::test_invalid_rerank_req[123] PASSED [ 31%] unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED [ 37%] unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED [ 43%] unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED [ 50%] unit/test_rerank.py::test_rerank_top_n[None-4] PASSED [ 56%] unit/test_rerank.py::test_rerank_top_n[2-2] PASSED [ 62%] unit/test_rerank.py::test_rerank_top_n[4-4] PASSED [ 68%] unit/test_rerank.py::test_rerank_top_n[99-4] PASSED [ 75%] unit/test_rerank.py::test_rerank_tei_top_n[None-4] PASSED [ 81%] unit/test_rerank.py::test_rerank_tei_top_n[2-2] PASSED [ 87%] unit/test_rerank.py::test_rerank_tei_top_n[4-4] PASSED [ 93%] unit/test_rerank.py::test_rerank_tei_top_n[99-4] PASSED [100%] =================================================================== 16 passed in 8.84s ===================================================================
1 parent 314c218 commit 7c9d34c

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tools/server/tests/unit/test_rerank.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens):
102102
assert res.status_code == 200
103103
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
104104
assert res.body['usage']['prompt_tokens'] == n_tokens
105+
106+
107+
@pytest.mark.parametrize("top_n,expected_len", [
108+
(None, len(TEST_DOCUMENTS)), # no top_n parameter
109+
(2, 2),
110+
(4, 4),
111+
(99, len(TEST_DOCUMENTS)), # higher than available docs
112+
])
113+
def test_rerank_top_n(top_n, expected_len):
114+
global server
115+
server.start()
116+
data = {
117+
"query": "Machine learning is",
118+
"documents": TEST_DOCUMENTS,
119+
}
120+
if top_n is not None:
121+
data["top_n"] = top_n
122+
123+
res = server.make_request("POST", "/rerank", data=data)
124+
assert res.status_code == 200
125+
assert len(res.body["results"]) == expected_len
126+
127+
128+
@pytest.mark.parametrize("top_n,expected_len", [
129+
(None, len(TEST_DOCUMENTS)), # no top_n parameter
130+
(2, 2),
131+
(4, 4),
132+
(99, len(TEST_DOCUMENTS)), # higher than available docs
133+
])
134+
def test_rerank_tei_top_n(top_n, expected_len):
135+
global server
136+
server.start()
137+
data = {
138+
"query": "Machine learning is",
139+
"texts": TEST_DOCUMENTS,
140+
}
141+
if top_n is not None:
142+
data["top_n"] = top_n
143+
144+
res = server.make_request("POST", "/rerank", data=data)
145+
assert res.status_code == 200
146+
assert len(res.body) == expected_len

0 commit comments

Comments
 (0)