Skip to content

Commit 377c95d

Browse files
committed
mongodb vector search instrument
1 parent 68f6b5c commit 377c95d

File tree

5 files changed

+85
-16
lines changed

5 files changed

+85
-16
lines changed

src/examples/mongo_vector_search_example/main.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from langtrace_python_sdk import langtrace, with_langtrace_root_span
12
import pymongo
23
import os
34
from dotenv import load_dotenv
45
from openai import OpenAI
5-
from langtrace_python_sdk import langtrace
66

77
load_dotenv()
8-
langtrace.init()
8+
langtrace.init(write_spans_to_console=False, batch=False)
99
MODEL = "text-embedding-ada-002"
1010
openai_client = OpenAI()
1111
client = pymongo.MongoClient(os.environ["MONGO_URI"])
@@ -20,22 +20,19 @@ def get_embedding(text):
2020
return embedding
2121

2222

23+
@with_langtrace_root_span("mongo-vector-search")
2324
def vector_query():
2425
db = client["sample_mflix"]
2526

2627
embedded_movies_collection = db["embedded_movies"]
27-
2828
# define pipeline
2929
pipeline = [
3030
{
3131
"$vectorSearch": {
3232
"index": "vector_index",
3333
"path": "plot_embedding",
34-
"queryVector": get_embedding(
35-
"A movie about a hacker that had a really rough childhood and been trying to convince his father otherwise."
36-
),
37-
# "numCandidates": 150,
38-
"exact": True,
34+
"queryVector": get_embedding("time travel"),
35+
"numCandidates": 150,
3936
"limit": 10,
4037
}
4138
},
@@ -51,13 +48,14 @@ def vector_query():
5148

5249
result = embedded_movies_collection.aggregate(pipeline)
5350
for doc in result:
54-
print(doc)
51+
# print(doc)
52+
pass
5553

5654

5755
if __name__ == "__main__":
5856
try:
5957
vector_query()
6058
except Exception as e:
61-
print(e)
59+
print("error", e)
6260
finally:
6361
client.close()

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"EMBEDCHAIN": "Embedchain",
3535
"AUTOGEN": "Autogen",
3636
"XAI": "XAI",
37+
"MONGODB": "MongoDB",
3738
}
3839

3940
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
APIS = {
22
"AGGREGATE": {
3-
"METHOD": "aggregate",
3+
"MODULE": "pymongo.collection",
4+
"METHOD": "Collection.aggregate",
45
"OPERATION": "aggregate",
6+
"SPAN_NAME": "MongoDB Aggregate",
57
},
68
}

src/langtrace_python_sdk/instrumentation/pymongo/instrumentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def _instrument(self, **kwargs):
3838
version = v("pymongo")
3939
for api in APIS.values():
4040
_W(
41-
module="pymongo.collection",
42-
name=f"Collection.{api['METHOD']}",
43-
wrapper=generic_patch(version, tracer),
41+
module=api["MODULE"],
42+
name=api["METHOD"],
43+
wrapper=generic_patch(api["SPAN_NAME"], version, tracer),
4444
)
4545

4646
def _uninstrument(self, **kwargs):
Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,73 @@
1-
def generic_patch(version, tracer):
1+
from langtrace_python_sdk.utils.llm import (
2+
get_langtrace_attributes,
3+
get_span_name,
4+
set_span_attributes,
5+
set_span_attribute,
6+
)
7+
from langtrace_python_sdk.utils import deduce_args_and_kwargs
8+
from opentelemetry.trace.status import Status, StatusCode
9+
from opentelemetry.trace import SpanKind
10+
from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
11+
from langtrace.trace_attributes import DatabaseSpanAttributes
12+
13+
import json
14+
15+
16+
def generic_patch(name, version, tracer):
    """Build a wrapper that runs a MongoDB collection call inside a client span.

    Args:
        name: Base span name; it is passed through ``get_span_name`` before
            the span is started.
        version: Version string forwarded to ``get_langtrace_attributes``.
        tracer: Tracer whose ``start_as_current_span`` is used to open the span.

    Returns:
        A ``traced_method(wrapped, instance, args, kwargs)`` callable. It
        returns whatever ``wrapped(*args, **kwargs)`` returns and re-raises
        any exception after recording it on the span.
    """

    def traced_method(wrapped, instance, args, kwargs):
        span_attributes = {
            **get_langtrace_attributes(
                version=version,
                service_provider=SERVICE_PROVIDERS["MONGODB"],
                vendor_type="vectordb",
            ),
            "db.system": "mongodb",
            "db.query": "aggregate",
        }
        attributes = DatabaseSpanAttributes(**span_attributes)

        with tracer.start_as_current_span(
            get_span_name(name), kind=SpanKind.CLIENT
        ) as span:
            if span.is_recording():
                set_input_attributes(
                    span, deduce_args_and_kwargs(wrapped, *args, **kwargs)
                )
                set_span_attributes(span, attributes)

            try:
                # Return the result untouched. The previous version printed it
                # and iterated it here to emit one event per document, but the
                # aggregate result is a single-pass cursor: consuming it inside
                # the wrapper left the caller with nothing to read. Raw
                # documents can also hold values that are not valid span
                # attribute values, so they are no longer attached as events.
                return wrapped(*args, **kwargs)
            except Exception as err:
                # Record the failure on the span, then let it propagate so the
                # instrumentation never swallows the caller's error.
                span.record_exception(err)
                span.set_status(Status(StatusCode.ERROR, str(err)))
                raise

    return traced_method
61+
62+
63+
def set_input_attributes(span, args):
    """Copy details of an aggregation pipeline onto ``span``.

    For a ``$vectorSearch`` stage the ``index``, ``path``, ``numCandidates``
    and ``limit`` settings are written as ``db.index``, ``db.path``,
    ``db.top_k`` and ``db.limit``. Every other stage is written under its own
    stage name as a JSON string.

    Args:
        span: Span handed to ``set_span_attribute`` for each attribute.
        args: Mapping of call arguments; only the ``"pipeline"`` entry is read.
            A missing, ``None`` or empty pipeline sets no attributes.
    """
    # ``or ()`` covers both an absent key and an explicit ``None`` so the
    # traced call is never broken by a TypeError raised from instrumentation.
    pipeline = args.get("pipeline") or ()
    for stage in pipeline:
        for stage_name, stage_body in stage.items():
            if stage_name == "$vectorSearch":
                set_span_attribute(span, "db.index", stage_body.get("index"))
                set_span_attribute(span, "db.path", stage_body.get("path"))
                set_span_attribute(
                    span, "db.top_k", stage_body.get("numCandidates")
                )
                set_span_attribute(span, "db.limit", stage_body.get("limit"))
            else:
                # Stage bodies may contain values json cannot encode natively;
                # stringify those deliberately rather than raising, since this
                # output is diagnostic only.
                set_span_attribute(
                    span, stage_name, json.dumps(stage_body, default=str)
                )

0 commit comments

Comments
 (0)