Skip to content

Commit 88a0858

Browse files
INTPYTHON-504 Add DriverInfo to MongoClients (#73)
1 parent ccdef94 commit 88a0858

File tree

8 files changed

+42
-22
lines changed

8 files changed

+42
-22
lines changed

libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
from importlib.metadata import version
34
from typing import Dict, List, Optional
45

56
from langchain_core.chat_history import BaseChatMessageHistory
@@ -9,6 +10,7 @@
910
messages_from_dict,
1011
)
1112
from pymongo import MongoClient, errors
13+
from pymongo.driver_info import DriverInfo
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -112,7 +114,12 @@ def __init__(
112114
self.client = client
113115
elif connection_string:
114116
try:
115-
self.client = MongoClient(connection_string)
117+
self.client = MongoClient(
118+
connection_string,
119+
driver=DriverInfo(
120+
name="Langchain", version=version("langchain-mongodb")
121+
),
122+
)
116123
except errors.ConnectionFailure as error:
117124
logger.error(error)
118125
else:

libs/langchain-mongodb/langchain_mongodb/indexes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
import functools
55
import warnings
6+
from importlib.metadata import version
67
from typing import Any, Dict, List, Optional, Sequence
78

89
from langchain_core.indexing.base import RecordManager
910
from langchain_core.runnables.config import run_in_executor
1011
from pymongo import MongoClient
1112
from pymongo.collection import Collection
13+
from pymongo.driver_info import DriverInfo
1214
from pymongo.errors import OperationFailure
1315

1416

@@ -47,7 +49,10 @@ def from_connection_string(
4749
Returns:
4850
A new MongoDBRecordManager instance.
4951
"""
50-
client: MongoClient = MongoClient(connection_string)
52+
client: MongoClient = MongoClient(
53+
connection_string,
54+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
55+
)
5156
db_name, collection_name = namespace.split(".")
5257
collection = client[db_name][collection_name]
5358
return cls(collection=collection)

libs/langchain-mongodb/langchain_mongodb/loaders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from __future__ import annotations
33

44
import logging
5+
from importlib.metadata import version
56
from typing import Dict, List, Optional, Sequence
67

78
from langchain_community.document_loaders.base import BaseLoader
89
from langchain_core.documents import Document
910
from langchain_core.runnables.config import run_in_executor
1011
from pymongo import MongoClient
1112
from pymongo.collection import Collection
13+
from pymongo.driver_info import DriverInfo
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -80,7 +82,10 @@ def from_connection_string(
8082
include_db_collection_in_metadata (bool): Flag to include database and
8183
collection names in metadata.
8284
"""
83-
client = MongoClient(connection_string)
85+
client = MongoClient(
86+
connection_string,
87+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
88+
)
8489
collection = client[db_name][collection_name]
8590
return MongoDBLoader(
8691
collection,

libs/langchain-mongodb/langchain_mongodb/retrievers/parent_document.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def from_connection_string(
168168
"""
169169
client: MongoClient = MongoClient(
170170
connection_string,
171-
driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
171+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
172172
)
173173
collection = client[database_name][collection_name]
174174
vectorstore = MongoDBAtlasVectorSearch(

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
7979
.. code-block:: python
8080
8181
import getpass
82-
MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")
82+
MONGODB_ATLAS_CONNECTION_STRING = getpass.getpass("MongoDB Atlas Connection String:")
8383
8484
Key init args — indexing params:
8585
embedding: Embeddings
@@ -99,20 +99,11 @@ class MongoDBAtlasVectorSearch(VectorStore):
9999
from pymongo import MongoClient
100100
from langchain_openai import OpenAIEmbeddings
101101
102-
# initialize MongoDB python client
103-
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
104-
105-
DB_NAME = "langchain_test_db"
106-
COLLECTION_NAME = "langchain_test_vectorstores"
107-
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain-test-index-vectorstores"
108-
109-
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]
110-
111-
vector_store = MongoDBAtlasVectorSearch(
112-
collection=MONGODB_COLLECTION,
102+
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
103+
connection_string=os=MONGODB_ATLAS_CONNECTION_STRING,
104+
namespace="db_name.collection_name",
113105
embedding=OpenAIEmbeddings(),
114-
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
115-
relevance_score_fn="cosine",
106+
index_name="vector_index",
116107
)
117108
118109
Add Documents:
@@ -279,7 +270,7 @@ def from_connection_string(
279270
"""
280271
client: MongoClient = MongoClient(
281272
connection_string,
282-
driver=DriverInfo(name="Langchain", version=version("langchain")),
273+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
283274
)
284275
db_name, collection_name = namespace.split(".")
285276
collection = client[db_name][collection_name]

libs/langchain-mongodb/tests/integration_tests/test_parent_document.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_1clxn_retriever(
3333
# Setup
3434
client: MongoClient = MongoClient(
3535
connection_string,
36-
driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
36+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
3737
)
3838
db = client[DB_NAME]
3939
combined_clxn = db[COLLECTION_NAME]

libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
2+
from importlib.metadata import version
23

34
import mongomock
45
import pytest
56
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
67
from langchain_core.messages import message_to_dict
8+
from pymongo.driver_info import DriverInfo
79
from pytest_mock import MockerFixture
810

911
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
@@ -59,7 +61,10 @@ def test_init_with_connection_string(mocker: MockerFixture) -> None:
5961
collection_name="test-collection",
6062
)
6163

62-
mock_mongo_client.assert_called_once_with("mongodb://localhost:27017/")
64+
mock_mongo_client.assert_called_once_with(
65+
"mongodb://localhost:27017/",
66+
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
67+
)
6368
assert history.session_id == "test-session"
6469
assert history.database_name == "test-database"
6570
assert history.collection_name == "test-collection"

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Iterator, Sequence
22
from contextlib import contextmanager
3+
from importlib.metadata import version
34
from typing import (
45
Any,
56
Optional,
@@ -8,6 +9,7 @@
89
from langchain_core.runnables import RunnableConfig
910
from pymongo import MongoClient, UpdateOne
1011
from pymongo.database import Database as MongoDatabase
12+
from pymongo.driver_info import DriverInfo
1113

1214
from langgraph.checkpoint.base import (
1315
WRITES_IDX_MAP,
@@ -88,7 +90,12 @@ def from_conn_string(
8890
"""
8991
client: Optional[MongoClient] = None
9092
try:
91-
client = MongoClient(conn_string)
93+
client = MongoClient(
94+
conn_string,
95+
driver=DriverInfo(
96+
name="Langgraph", version=version("langgraph-checkpoint-mongodb")
97+
),
98+
)
9299
yield MongoDBSaver(
93100
client,
94101
db_name,

0 commit comments

Comments
 (0)