Skip to content

Commit 7143763

Browse files
committed
Address reviews
1 parent c4cc7f5 commit 7143763

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

nemo_retriever/src/nemo_retriever/rerank/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
rerank_hits
1414
Convenience function to rerank a list of LanceDB hit dicts for a single
1515
query string, using either a local ``NemotronRerankV2`` model or a remote
16-
vLLM / NIM ranking endpoint.
16+
vLLM / NIM ``/v1/ranking`` endpoint.
1717
"""
1818

1919
from .rerank import NemotronRerankActor, NemotronRerankCPUActor, NemotronRerankGPUActor, rerank_hits

nemo_retriever/src/nemo_retriever/rerank/rerank.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Remote endpoint
1313
---------------
1414
When ``invoke_url`` is set the actor/function calls a vLLM (>=0.14) or NIM
15-
server that exposes the OpenAI-compatible ranking REST API. The helper accepts
15+
server that exposes the NIM ranking REST API. The helper accepts
1616
either a fully qualified ``.../reranking`` URL or a base URL, in which case
1717
``/v1/ranking`` is appended automatically::
1818

nemo_retriever/tests/test_nemotron_rerank_v2.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def test_original_hit_keys_preserved(self):
402402

403403

404404
class TestRerankViaEndpoint:
405-
def test_posts_to_rerank_url(self):
405+
def test_posts_to_ranking_url(self):
406406
from nemo_retriever.rerank.rerank import _rerank_via_endpoint
407407

408408
mock_resp = MagicMock()
@@ -426,6 +426,8 @@ def test_posts_to_rerank_url(self):
426426
call_kwargs = mock_post.call_args
427427
assert call_kwargs[0][0] == "http://localhost:8000/v1/ranking"
428428
assert call_kwargs[1]["json"]["query"] == {"text": "What is ML?"}
429+
assert call_kwargs[1]["json"]["truncate"] == "END"
430+
assert call_kwargs[1]["json"]["passages"][0] == {"text": "Machine learning is…"}
429431
assert len(call_kwargs[1]["json"]["passages"]) == 2
430432

431433
assert scores == [0.9, 0.3]
@@ -459,7 +461,7 @@ def test_authorization_header_sent_when_api_key_provided(self):
459461
from nemo_retriever.rerank.rerank import _rerank_via_endpoint
460462

461463
mock_resp = MagicMock()
462-
mock_resp.json.return_value = {"results": [{"index": 0, "relevance_score": 1.0}]}
464+
mock_resp.json.return_value = {"rankings": [{"index": 0, "logit": 1.0}]}
463465
mock_resp.raise_for_status = MagicMock()
464466

465467
with patch("requests.post", return_value=mock_resp) as mock_post:
@@ -477,7 +479,7 @@ def test_trailing_slash_on_endpoint_normalized(self):
477479
from nemo_retriever.rerank.rerank import _rerank_via_endpoint
478480

479481
mock_resp = MagicMock()
480-
mock_resp.json.return_value = {"results": [{"index": 0, "relevance_score": 0.5}]}
482+
mock_resp.json.return_value = {"rankings": [{"index": 0, "logit": 0.5}]}
481483
mock_resp.raise_for_status = MagicMock()
482484

483485
with patch("requests.post", return_value=mock_resp) as mock_post:
@@ -490,7 +492,7 @@ def test_top_n_not_in_payload_when_not_specified(self):
490492
from nemo_retriever.rerank.rerank import _rerank_via_endpoint
491493

492494
mock_resp = MagicMock()
493-
mock_resp.json.return_value = {"results": [{"index": 0, "relevance_score": 0.5}]}
495+
mock_resp.json.return_value = {"rankings": [{"index": 0, "logit": 0.5}]}
494496
mock_resp.raise_for_status = MagicMock()
495497

496498
with patch("requests.post", return_value=mock_resp) as mock_post:
@@ -532,8 +534,8 @@ def test_actor_call_scores_dataframe(self):
532534
mock_resp = MagicMock()
533535
mock_resp.raise_for_status = MagicMock()
534536
mock_resp.json.side_effect = [
535-
{"results": [{"index": 0, "relevance_score": 0.9}]},
536-
{"results": [{"index": 0, "relevance_score": 0.4}]},
537+
{"rankings": [{"index": 0, "logit": 0.9}]},
538+
{"rankings": [{"index": 0, "logit": 0.4}]},
537539
]
538540

539541
with patch("requests.post", return_value=mock_resp):
@@ -552,8 +554,8 @@ def test_actor_call_sorts_descending_by_default(self):
552554
mock_resp = MagicMock()
553555
mock_resp.raise_for_status = MagicMock()
554556
mock_resp.json.side_effect = [
555-
{"results": [{"index": 0, "relevance_score": 0.1}]},
556-
{"results": [{"index": 0, "relevance_score": 0.9}]},
557+
{"rankings": [{"index": 0, "logit": 0.1}]},
558+
{"rankings": [{"index": 0, "logit": 0.9}]},
557559
]
558560

559561
with patch("requests.post", return_value=mock_resp):
@@ -587,7 +589,7 @@ def test_actor_custom_score_column_name(self):
587589

588590
mock_resp = MagicMock()
589591
mock_resp.raise_for_status = MagicMock()
590-
mock_resp.json.return_value = {"results": [{"index": 0, "relevance_score": 0.7}]}
592+
mock_resp.json.return_value = {"rankings": [{"index": 0, "logit": 0.7}]}
591593

592594
with patch("requests.post", return_value=mock_resp):
593595
out = actor(df)

0 commit comments

Comments
 (0)