Skip to content

Commit a647073

Browse files
cbornetmdrxyCopilot
authored
feat(standard-tests): add a property to set the name of the parameter for the number of results to return (#32443)
Not all retrievers use `k` as param name to set the number of results to return. Even in LangChain itself. Eg: https://github.com/langchain-ai/langchain/blob/bc4251b9e0074faf852c9b4e184ba681a594b03e/libs/core/langchain_core/indexing/in_memory.py#L31 So it's helpful to be able to change it for a given retriever. The change also adds hints to disable the tests if the retriever doesn't support setting the param in the constructor or in the invoke method (for instance, the `InMemoryDocumentIndex` in the link supports in the constructor but not in the invoke method). This change is backward compatible. --------- Co-authored-by: Mason Daugherty <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent e120604 commit a647073

File tree

1 file changed

+70
-18
lines changed
  • libs/standard-tests/langchain_tests/integration_tests

1 file changed

+70
-18
lines changed

libs/standard-tests/langchain_tests/integration_tests/retrievers.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,49 @@ def retriever_constructor_params(self) -> dict:
2424
@property
2525
@abstractmethod
2626
def retriever_query_example(self) -> str:
27-
"""Returns a str representing the "query" of an example retriever call."""
28-
...
27+
"""Returns a str representing the ``query`` of an example retriever call."""
28+
29+
@property
30+
def num_results_arg_name(self) -> str:
31+
"""Returns the name of the parameter for the number of results returned.
32+
33+
Usually something like ``k`` or ``top_k``."""
34+
return "k"
2935

3036
@pytest.fixture
3137
def retriever(self) -> BaseRetriever:
3238
""":private:"""
3339
return self.retriever_constructor(**self.retriever_constructor_params)
3440

3541
def test_k_constructor_param(self) -> None:
36-
"""Test that the retriever constructor accepts a k parameter, representing
42+
"""Test the number of results constructor parameter.
43+
44+
Test that the retriever constructor accepts a parameter representing
3745
the number of documents to return.
3846
47+
By default, the parameter tested is named ``k``, but it can be overridden by
48+
setting the ``num_results_arg_name`` property.
49+
50+
.. note::
51+
If the retriever doesn't support configuring the number of results returned
52+
via the constructor, this test can be skipped using a pytest ``xfail`` on
53+
the test class:
54+
55+
.. code-block:: python
56+
57+
@pytest.mark.xfail(
58+
reason="This retriever doesn't support setting "
59+
"the number of results via the constructor."
60+
)
61+
def test_k_constructor_param(self) -> None:
62+
raise NotImplementedError
63+
3964
.. dropdown:: Troubleshooting
4065
41-
If this test fails, either the retriever constructor does not accept a k
42-
parameter, or the retriever does not return the correct number of documents
43-
(`k`) when it is set.
66+
If this test fails, the retriever constructor does not accept a number
67+
of results parameter, or the retriever does not return the correct number
68+
of documents ( of the one set in ``num_results_arg_name``) when it is
69+
set.
4470
4571
For example, a retriever like
4672
@@ -52,29 +78,51 @@ def test_k_constructor_param(self) -> None:
5278
5379
"""
5480
params = {
55-
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
81+
k: v
82+
for k, v in self.retriever_constructor_params.items()
83+
if k != self.num_results_arg_name
5684
}
57-
params_3 = {**params, "k": 3}
85+
params_3 = {**params, self.num_results_arg_name: 3}
5886
retriever_3 = self.retriever_constructor(**params_3)
5987
result_3 = retriever_3.invoke(self.retriever_query_example)
6088
assert len(result_3) == 3
6189
assert all(isinstance(doc, Document) for doc in result_3)
6290

63-
params_1 = {**params, "k": 1}
91+
params_1 = {**params, self.num_results_arg_name: 1}
6492
retriever_1 = self.retriever_constructor(**params_1)
6593
result_1 = retriever_1.invoke(self.retriever_query_example)
6694
assert len(result_1) == 1
6795
assert all(isinstance(doc, Document) for doc in result_1)
6896

6997
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
70-
"""Test that the invoke method accepts a k parameter, representing the number of
71-
documents to return.
98+
"""Test the number of results parameter in ``invoke()``.
99+
100+
Test that the invoke method accepts a parameter representing
101+
the number of documents to return.
102+
103+
By default, the parameter is named ``, but it can be overridden by
104+
setting the ``num_results_arg_name`` property.
105+
106+
.. note::
107+
If the retriever doesn't support configuring the number of results returned
108+
via the invoke method, this test can be skipped using a pytest ``xfail`` on
109+
the test class:
110+
111+
.. code-block:: python
112+
113+
@pytest.mark.xfail(
114+
reason="This retriever doesn't support setting "
115+
"the number of results in the invoke method."
116+
)
117+
def test_invoke_with_k_kwarg(self) -> None:
118+
raise NotImplementedError
72119
73120
.. dropdown:: Troubleshooting
74121
75-
If this test fails, the retriever's invoke method does not accept a k
76-
parameter, or the retriever does not return the correct number of documents
77-
(`k`) when it is set.
122+
If this test fails, the retriever's invoke method does not accept a number
123+
of results parameter, or the retriever does not return the correct number
124+
of documents (``k`` of the one set in ``num_results_arg_name``) when it is
125+
set.
78126
79127
For example, a retriever like
80128
@@ -85,11 +133,15 @@ def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
85133
should return 3 documents when invoked with a query.
86134
87135
"""
88-
result_1 = retriever.invoke(self.retriever_query_example, k=1)
136+
result_1 = retriever.invoke(
137+
self.retriever_query_example, None, **{self.num_results_arg_name: 1}
138+
)
89139
assert len(result_1) == 1
90140
assert all(isinstance(doc, Document) for doc in result_1)
91141

92-
result_3 = retriever.invoke(self.retriever_query_example, k=3)
142+
result_3 = retriever.invoke(
143+
self.retriever_query_example, None, **{self.num_results_arg_name: 3}
144+
)
93145
assert len(result_3) == 3
94146
assert all(isinstance(doc, Document) for doc in result_3)
95147

@@ -100,8 +152,8 @@ def test_invoke_returns_documents(self, retriever: BaseRetriever) -> None:
100152
.. dropdown:: Troubleshooting
101153
102154
If this test fails, the retriever's invoke method does not return a list of
103-
`langchain_core.document.Document` objects. Please confirm that your
104-
`_get_relevant_documents` method returns a list of `Document` objects.
155+
``langchain_core.document.Document`` objects. Please confirm that your
156+
``_get_relevant_documents`` method returns a list of ``Document`` objects.
105157
"""
106158
result = retriever.invoke(self.retriever_query_example)
107159

0 commit comments

Comments
 (0)