Skip to content

Commit b832950

Browse files
committed
improve handling of warnings
1 parent 209b867 commit b832950

14 files changed

+119
-34
lines changed

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Annotated, Any, Dict, List, Optional, Union
23

34
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
@@ -46,7 +47,13 @@ def _get_relevant_documents(
4647
Returns:
4748
List of relevant documents
4849
"""
49-
default_k = self.k if self.k is not None else self.top_k
50+
is_top_k_set = False
51+
with warnings.catch_warnings():
52+
# Ignore warning raised by checking the value of top_k.
53+
warnings.simplefilter("ignore", DeprecationWarning)
54+
if self.top_k is not None:
55+
is_top_k_set = True
56+
default_k = self.k if not is_top_k_set else self.top_k
5057
pipeline = text_search_stage( # type: ignore
5158
query=query,
5259
search_field=self.search_field,

libs/langchain-mongodb/langchain_mongodb/retrievers/hybrid_search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Annotated, Any, Dict, List, Optional
23

34
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
@@ -83,7 +84,13 @@ def _get_relevant_documents(
8384
pipeline: List[Any] = []
8485

8586
# Get the appropriate value for k.
86-
default_k = self.top_k if self.top_k is not None else self.k
87+
is_top_k_set = False
88+
with warnings.catch_warnings():
89+
# Ignore warnings raised by base class.
90+
warnings.simplefilter("ignore", DeprecationWarning)
91+
if self.top_k is not None:
92+
is_top_k_set = True
93+
default_k = self.k if not is_top_k_set else self.top_k
8794
k = kwargs.get("k", default_k)
8895

8996
# First we build up the aggregation pipeline,

libs/langchain-mongodb/tests/integration_tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from typing import Generator, List
34

45
import pytest
@@ -16,7 +17,10 @@
1617
def technical_report_pages() -> List[Document]:
1718
"""Returns a Document for each of the 100 pages of a GPT-4 Technical Report"""
1819
loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf")
19-
pages = loader.load()
20+
with warnings.catch_warnings():
21+
# Ignore warnings raised by base class.
22+
warnings.simplefilter("ignore", ResourceWarning)
23+
pages = loader.load()
2024
return pages
2125

2226

libs/langchain-mongodb/tests/integration_tests/test_cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def test_mongodb_cache(
157157
_execute_test(prompt, llm, response)
158158
finally:
159159
get_llm_cache().clear()
160+
get_llm_cache().close()
160161

161162

162163
@pytest.mark.parametrize(
@@ -208,3 +209,4 @@ def test_mongodb_atlas_cache_matrix(
208209
generations=llm_generations, llm_output={}
209210
)
210211
get_llm_cache().clear()
212+
get_llm_cache().close()

libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import warnings
23

34
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
45
from langchain_core.messages import message_to_dict
@@ -19,9 +20,12 @@ def test_memory_with_message_store() -> None:
1920
database_name=DB_NAME,
2021
collection_name=COLLECTION,
2122
)
22-
memory = ConversationBufferMemory(
23-
memory_key="baz", chat_memory=message_history, return_messages=True
24-
)
23+
with warnings.catch_warnings():
24+
# Ignore warnings raised by base class.
25+
warnings.simplefilter("ignore", DeprecationWarning)
26+
memory = ConversationBufferMemory(
27+
memory_key="baz", chat_memory=message_history, return_messages=True
28+
)
2529

2630
# add some messages
2731
memory.chat_memory.add_ai_message("This is me, the AI")
@@ -38,3 +42,4 @@ def test_memory_with_message_store() -> None:
3842
memory.chat_memory.clear()
3943

4044
assert memory.chat_memory.messages == []
45+
memory.chat_memory.close()

libs/langchain-mongodb/tests/integration_tests/test_parent_document.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,4 @@ def test_1clxn_retriever(
7474
assert len(responses) == 3
7575
assert all("GPT-4" in doc.page_content for doc in responses)
7676
assert {4, 5, 29} == set(doc.metadata["page"] for doc in responses)
77+
client.close()

libs/langchain-mongodb/tests/integration_tests/test_retrievers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,14 @@ def test_hybrid_retriever_deprecated_top_k(
194194
)
195195

196196
query1 = "When did I visit France?"
197-
results = retriever.invoke(query1)
197+
with pytest.warns(DeprecationWarning):
198+
results = retriever.invoke(query1)
198199
assert len(results) == 3
199200
assert "Paris" in results[0].page_content
200201

201202
query2 = "When was the last time I visited new orleans?"
202-
results = retriever.invoke(query2)
203+
with pytest.warns(DeprecationWarning):
204+
results = retriever.invoke(query2)
203205
assert "New Orleans" in results[0].page_content
204206

205207

libs/langchain-mongodb/tests/integration_tests/test_retrievers_multi_field.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,14 @@ def test_hybrid_retriever_deprecated_top_k(
207207
)
208208

209209
query1 = "When did I visit France?"
210-
results = retriever.invoke(query1)
210+
with pytest.warns(DeprecationWarning):
211+
results = retriever.invoke(query1)
211212
assert len(results) == 3
212213
assert "Paris" in results[0].page_content
213214

214215
query2 = "When was the last time I visited new orleans?"
215-
results = retriever.invoke(query2)
216+
with pytest.warns(DeprecationWarning):
217+
results = retriever.invoke(query2)
216218
assert "New Orleans" in results[0].page_content
217219

218220

libs/langchain-mongodb/tests/integration_tests/test_retrievers_standard.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def setup_test() -> tuple[Collection, MongoDBAtlasVectorSearch]:
4444
text_key=PAGE_CONTENT_FIELD,
4545
auto_index_timeout=TIMEOUT,
4646
)
47-
48-
if coll.count_documents({}) == 0:
49-
vs.add_documents(
50-
[
51-
Document(page_content="In 2023, I visited Paris"),
52-
Document(page_content="In 2022, I visited New York"),
53-
Document(page_content="In 2021, I visited New Orleans"),
54-
Document(page_content="Sandwiches are beautiful. Sandwiches are fine."),
55-
]
56-
)
47+
coll.delete_many({})
48+
49+
vs.add_documents(
50+
[
51+
Document(page_content="In 2023, I visited Paris"),
52+
Document(page_content="In 2022, I visited New York"),
53+
Document(page_content="In 2021, I visited New Orleans"),
54+
Document(page_content="Sandwiches are beautiful. Sandwiches are fine."),
55+
]
56+
)
5757

5858
# Set up the search index if needed.
5959
if not any([ix["name"] == SEARCH_INDEX_NAME for ix in coll.list_search_indexes()]):
@@ -68,16 +68,23 @@ def setup_test() -> tuple[Collection, MongoDBAtlasVectorSearch]:
6868

6969

7070
class TestMongoDBAtlasFullTextSearchRetriever(RetrieversIntegrationTests):
71+
@classmethod
72+
def setup_class(cls):
73+
cls._coll, _ = setup_test()
74+
75+
@classmethod
76+
def teardown_class(cls):
77+
cls._coll.database.client.close()
78+
7179
@property
7280
def retriever_constructor(self) -> Type[MongoDBAtlasFullTextSearchRetriever]:
7381
"""Get a retriever for integration tests."""
7482
return MongoDBAtlasFullTextSearchRetriever
7583

7684
@property
7785
def retriever_constructor_params(self) -> dict:
78-
coll, _ = setup_test()
7986
return {
80-
"collection": coll,
87+
"collection": self._coll,
8188
"search_index_name": SEARCH_INDEX_NAME,
8289
"search_field": PAGE_CONTENT_FIELD,
8390
}
@@ -91,17 +98,24 @@ def retriever_query_example(self) -> str:
9198

9299

93100
class TestMongoDBAtlasHybridSearchRetriever(RetrieversIntegrationTests):
101+
@classmethod
102+
def setup_class(cls):
103+
cls._coll, cls._vs = setup_test()
104+
105+
@classmethod
106+
def teardown_class(cls):
107+
cls._coll.database.client.close()
108+
94109
@property
95110
def retriever_constructor(self) -> Type[MongoDBAtlasHybridSearchRetriever]:
96111
"""Get a retriever for integration tests."""
97112
return MongoDBAtlasHybridSearchRetriever
98113

99114
@property
100115
def retriever_constructor_params(self) -> dict:
101-
coll, vs = setup_test()
102116
return {
103-
"vectorstore": vs,
104-
"collection": coll,
117+
"vectorstore": self._vs,
118+
"collection": self._coll,
105119
"search_index_name": SEARCH_INDEX_NAME,
106120
"search_field": PAGE_CONTENT_FIELD,
107121
}

libs/langchain-mongodb/tests/integration_tests/test_tools.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,55 +12,87 @@
1212

1313

1414
class TestQueryMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
15+
@classmethod
16+
def setup_class(cls):
17+
cls._db = create_database()
18+
19+
@classmethod
20+
def teardown_class(cls):
21+
cls._db.close()
22+
1523
@property
1624
def tool_constructor(self) -> Type[QueryMongoDBDatabaseTool]:
1725
return QueryMongoDBDatabaseTool
1826

1927
@property
2028
def tool_constructor_params(self) -> dict:
21-
return dict(db=create_database())
29+
return dict(db=self._db)
2230

2331
@property
2432
def tool_invoke_params_example(self) -> dict:
2533
return dict(query='db.test.aggregate([{"$match": {}}])')
2634

2735

2836
class TestInfoMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
37+
@classmethod
38+
def setup_class(cls):
39+
cls._db = create_database()
40+
41+
@classmethod
42+
def teardown_class(cls):
43+
cls._db.close()
44+
2945
@property
3046
def tool_constructor(self) -> Type[InfoMongoDBDatabaseTool]:
3147
return InfoMongoDBDatabaseTool
3248

3349
@property
3450
def tool_constructor_params(self) -> dict:
35-
return dict(db=create_database())
51+
return dict(db=self._db)
3652

3753
@property
3854
def tool_invoke_params_example(self) -> dict:
3955
return dict(collection_names="test")
4056

4157

4258
class TestListMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
59+
@classmethod
60+
def setup_class(cls):
61+
cls._db = create_database()
62+
63+
@classmethod
64+
def teardown_class(cls):
65+
cls._db.close()
66+
4367
@property
4468
def tool_constructor(self) -> Type[ListMongoDBDatabaseTool]:
4569
return ListMongoDBDatabaseTool
4670

4771
@property
4872
def tool_constructor_params(self) -> dict:
49-
return dict(db=create_database())
73+
return dict(db=self._db)
5074

5175
@property
5276
def tool_invoke_params_example(self) -> dict:
5377
return dict()
5478

5579

5680
class TestQueryMongoDBCheckerToolIntegration(ToolsIntegrationTests):
81+
@classmethod
82+
def setup_class(cls):
83+
cls._db = create_database()
84+
85+
@classmethod
86+
def teardown_class(cls):
87+
cls._db.close()
88+
5789
@property
5890
def tool_constructor(self) -> Type[QueryMongoDBCheckerTool]:
5991
return QueryMongoDBCheckerTool
6092

6193
@property
6294
def tool_constructor_params(self) -> dict:
63-
return dict(db=create_database(), llm=create_llm())
95+
return dict(db=self._db, llm=create_llm())
6496

6597
@property
6698
def tool_invoke_params_example(self) -> dict:

0 commit comments

Comments
 (0)