Skip to content

Commit 9c4b2a9

Browse files
authored
Merge pull request #379 from Scale3-Labs/ali/s3en-2856-support-for-mongodb-vector-db
mongodb kickstart + openai embedding enrichment
2 parents 8c56527 + cf1b758 commit 9c4b2a9

File tree

10 files changed

+207
-1
lines changed

10 files changed

+207
-1
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from langtrace_python_sdk import langtrace, with_langtrace_root_span
2+
import pymongo
3+
import os
4+
from dotenv import load_dotenv
5+
from openai import OpenAI
6+
7+
load_dotenv()
8+
langtrace.init(write_spans_to_console=False, batch=False)
9+
MODEL = "text-embedding-ada-002"
10+
openai_client = OpenAI()
11+
client = pymongo.MongoClient(os.environ["MONGO_URI"])
12+
13+
14+
# Define a function to generate embeddings
15+
def get_embedding(text):
    """Return the embedding vector OpenAI produces for ``text``.

    Sends ``text`` as a single-item input to the embeddings endpoint using
    the module-level ``MODEL`` and returns the ``embedding`` field of the
    first item in the response's ``data`` list.
    """
    response = openai_client.embeddings.create(input=[text], model=MODEL)
    first_item = response.data[0]
    return first_item.embedding
21+
22+
23+
@with_langtrace_root_span("mongo-vector-search")
def vector_query():
    """Run a ``$vectorSearch`` aggregation for the phrase "time travel".

    Queries the ``embedded_movies`` collection of the ``sample_mflix``
    database through the ``vector_index`` index, projecting only the plot,
    title and vector-search score of each match. The cursor is iterated to
    completion and the documents are discarded.
    """
    movies = client["sample_mflix"]["embedded_movies"]

    # Stage 1: vector search over the stored plot embeddings.
    search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "path": "plot_embedding",
            "queryVector": get_embedding("time travel"),
            "numCandidates": 150,
            "limit": 10,
        }
    }
    # Stage 2: keep only the fields of interest plus the match score.
    projection_stage = {
        "$project": {
            "_id": 0,
            "plot": 1,
            "title": 1,
            "score": {"$meta": "vectorSearchScore"},
        }
    }

    cursor = movies.aggregate([search_stage, projection_stage])
    for _ in cursor:
        pass
53+
54+
55+
if __name__ == "__main__":
    # Run the demo query, report any failure on stdout, and always close
    # the MongoDB client afterwards.
    try:
        vector_query()
    except Exception as exc:
        print("error", exc)
    finally:
        client.close()

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"EMBEDCHAIN": "Embedchain",
3535
"AUTOGEN": "Autogen",
3636
"XAI": "XAI",
37+
"MONGODB": "MongoDB",
3738
"AWS_BEDROCK": "AWS Bedrock",
3839
"CEREBRAS": "Cerebras",
3940
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
APIS = {
2+
"AGGREGATE": {
3+
"MODULE": "pymongo.collection",
4+
"METHOD": "Collection.aggregate",
5+
"OPERATION": "aggregate",
6+
"SPAN_NAME": "MongoDB Aggregate",
7+
},
8+
}

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .aws_bedrock import AWSBedrockInstrumentation
2222
from .embedchain import EmbedchainInstrumentation
2323
from .litellm import LiteLLMInstrumentation
24+
from .pymongo import PyMongoInstrumentation
2425
from .cerebras import CerebrasInstrumentation
2526

2627
__all__ = [
@@ -46,6 +47,7 @@
4647
"VertexAIInstrumentation",
4748
"GeminiInstrumentation",
4849
"MistralInstrumentation",
50+
"PyMongoInstrumentation",
4951
"AWSBedrockInstrumentation",
5052
"CerebrasInstrumentation",
5153
]

src/langtrace_python_sdk/instrumentation/openai/patch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set_event_completion,
2828
StreamWrapper,
2929
set_span_attributes,
30+
set_usage_attributes,
3031
)
3132
from langtrace_python_sdk.types import NOT_GIVEN
3233

@@ -450,6 +451,14 @@ def traced_method(
450451
span_attributes[SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS] = json.dumps(
451452
[kwargs.get("input", "")]
452453
)
454+
span_attributes[SpanAttributes.LLM_PROMPTS] = json.dumps(
455+
[
456+
{
457+
"role": "user",
458+
"content": kwargs.get("input"),
459+
}
460+
]
461+
)
453462

454463
attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
455464

@@ -463,6 +472,11 @@ def traced_method(
463472
try:
464473
# Attempt to call the original method
465474
result = wrapped(*args, **kwargs)
475+
usage = getattr(result, "usage", None)
476+
if usage:
477+
set_usage_attributes(
478+
span, {"prompt_tokens": getattr(usage, "prompt_tokens", 0)}
479+
)
466480
span.set_status(StatusCode.OK)
467481
return result
468482
except Exception as err:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .instrumentation import PyMongoInstrumentation
2+
3+
__all__ = [
4+
"PyMongoInstrumentation",
5+
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
18+
from opentelemetry.trace import get_tracer
19+
20+
from typing import Collection
21+
from importlib_metadata import version as v
22+
from wrapt import wrap_function_wrapper as _W
23+
from .patch import generic_patch
24+
from langtrace_python_sdk.constants.instrumentation.pymongo import APIS
25+
26+
27+
class PyMongoInstrumentation(BaseInstrumentor):
    """
    The PyMongoInstrumentation class represents the PyMongo instrumentation.

    Every entry of ``APIS`` is wrapped with a tracing wrapper built by
    ``generic_patch`` on instrument, and unwrapped again on uninstrument.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the requirement specifiers this instrumentor depends on."""
        return ["pymongo >= 4.0.0"]

    def _instrument(self, **kwargs):
        """Wrap each API listed in ``APIS`` with a tracing wrapper.

        Args:
            **kwargs: May contain ``tracer_provider``, which is forwarded to
                ``get_tracer``.
        """
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = v("pymongo")
        for api in APIS.values():
            _W(
                module=api["MODULE"],
                name=api["METHOD"],
                wrapper=generic_patch(api["SPAN_NAME"], version, tracer),
            )

    def _uninstrument(self, **kwargs):
        """Restore the original, unwrapped callables for every API in ``APIS``.

        Without this, wrappers installed by ``_instrument`` would stay in
        place after ``uninstrument()`` and a later re-instrumentation would
        stack a second wrapper on top of the first.
        """
        # Local imports keep this block self-contained.
        from importlib import import_module

        from opentelemetry.instrumentation.utils import unwrap

        for api in APIS.values():
            owner = import_module(api["MODULE"])
            # "Collection.aggregate" -> walk to the class, then unwrap the method.
            *parents, method_name = api["METHOD"].split(".")
            for part in parents:
                owner = getattr(owner, part)
            unwrap(owner, method_name)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from langtrace_python_sdk.utils.llm import (
2+
get_langtrace_attributes,
3+
get_span_name,
4+
set_span_attributes,
5+
set_span_attribute,
6+
)
7+
from langtrace_python_sdk.utils import deduce_args_and_kwargs, handle_span_error
8+
from opentelemetry.trace import SpanKind
9+
from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
10+
from langtrace.trace_attributes import DatabaseSpanAttributes
11+
12+
import json
13+
14+
15+
def generic_patch(name, version, tracer):
    """Build a wrapt-style wrapper that traces a PyMongo collection call.

    Args:
        name: Span name, passed through ``get_span_name``.
        version: Installed pymongo version, recorded in the span attributes.
        tracer: OpenTelemetry tracer used to start the span.

    Returns:
        A ``traced_method(wrapped, instance, args, kwargs)`` callable that
        runs ``wrapped`` inside a CLIENT span and returns its result
        unchanged.
    """

    def traced_method(wrapped, instance, args, kwargs):
        span_attributes = {
            **get_langtrace_attributes(
                version=version,
                service_provider=SERVICE_PROVIDERS["MONGODB"],
                vendor_type="vectordb",
            ),
            "db.system": "mongodb",
            "db.query": "aggregate",
        }

        attributes = DatabaseSpanAttributes(**span_attributes)

        with tracer.start_as_current_span(
            get_span_name(name), kind=SpanKind.CLIENT
        ) as span:
            if span.is_recording():
                set_input_attributes(
                    span, deduce_args_and_kwargs(wrapped, *args, **kwargs)
                )
                set_span_attributes(span, attributes)

            try:
                # Hand the cursor back untouched. The result of an aggregate
                # call is a forward-only cursor; iterating it here (to emit
                # per-document events) would leave the caller with an
                # exhausted cursor and no documents.
                return wrapped(*args, **kwargs)
            except Exception as err:
                handle_span_error(span, err)
                raise

    return traced_method
54+
55+
56+
def set_input_attributes(span, args):
    """Record the aggregation pipeline's stages as span attributes.

    For a ``$vectorSearch`` stage the index, path, ``numCandidates`` and
    ``limit`` are recorded as ``db.index``, ``db.path``, ``db.top_k`` and
    ``db.limit``. Every other stage is recorded under its operator name as
    a JSON string.

    Args:
        span: Span to annotate.
        args: Mapping of the wrapped call's arguments; the pipeline is read
            from its ``"pipeline"`` key. A missing or ``None`` pipeline is
            treated as empty so tracing never breaks the traced call.
    """
    pipeline = args.get("pipeline") or []
    for stage in pipeline:
        for operator, spec in stage.items():
            if operator == "$vectorSearch":
                set_span_attribute(span, "db.index", spec.get("index", None))
                set_span_attribute(span, "db.path", spec.get("path", None))
                set_span_attribute(span, "db.top_k", spec.get("numCandidates"))
                set_span_attribute(span, "db.limit", spec.get("limit"))
            else:
                # default=str is deliberate: stages may hold BSON values
                # (ObjectId, datetime, ...) that json cannot encode, and a
                # TypeError here would abort the user's query.
                set_span_attribute(span, operator, json.dumps(spec, default=str))

src/langtrace_python_sdk/langtrace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
AutogenInstrumentation,
6565
VertexAIInstrumentation,
6666
WeaviateInstrumentation,
67+
PyMongoInstrumentation,
6768
CerebrasInstrumentation,
6869
)
6970
from opentelemetry.util.re import parse_env_headers
@@ -281,6 +282,7 @@ def init(
281282
"mistralai": MistralInstrumentation(),
282283
"boto3": AWSBedrockInstrumentation(),
283284
"autogen": AutogenInstrumentation(),
285+
"pymongo": PyMongoInstrumentation(),
284286
"cerebras-cloud-sdk": CerebrasInstrumentation(),
285287
}
286288

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.3.2"
1+
__version__ = "3.3.3"

0 commit comments

Comments
 (0)