Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/examples/mongo_vector_search_example/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from langtrace_python_sdk import langtrace, with_langtrace_root_span
import pymongo
import os
from dotenv import load_dotenv
from openai import OpenAI

# Example-script setup. load_dotenv() is called first so that the environment
# lookups below (os.environ["MONGO_URI"], and presumably the OpenAI credentials
# read by OpenAI() -- confirm) can be satisfied from a local .env file.
load_dotenv()
# Initialise Langtrace with console span output and batching both turned off.
langtrace.init(write_spans_to_console=False, batch=False)
# Embedding model name passed to the OpenAI embeddings endpoint in get_embedding().
MODEL = "text-embedding-ada-002"
openai_client = OpenAI()
# Raises KeyError at import time if MONGO_URI is not set in the environment.
client = pymongo.MongoClient(os.environ["MONGO_URI"])


# Define a function to generate embeddings
def get_embedding(text):
    """Return the embedding vector that the OpenAI client produces for ``text``.

    The text is sent as a single-element input list using the module-level
    ``MODEL``; the embedding of the first (only) returned item is handed back.
    """
    response = openai_client.embeddings.create(input=[text], model=MODEL)
    first_item = response.data[0]
    return first_item.embedding


@with_langtrace_root_span("mongo-vector-search")
def vector_query():
    """Run a sample ``$vectorSearch`` aggregation for the phrase "time travel".

    The query targets ``sample_mflix.embedded_movies`` and projects each hit's
    plot, title and vector-search score. The cursor is drained and the
    documents are discarded; the function returns ``None``.
    """
    collection = client["sample_mflix"]["embedded_movies"]

    # Stage 1: vector search against the "plot_embedding" field.
    search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "path": "plot_embedding",
            "queryVector": get_embedding("time travel"),
            "numCandidates": 150,
            "limit": 10,
        }
    }
    # Stage 2: keep only the fields of interest plus the similarity score.
    projection_stage = {
        "$project": {
            "_id": 0,
            "plot": 1,
            "title": 1,
            "score": {"$meta": "vectorSearchScore"},
        }
    }

    cursor = collection.aggregate([search_stage, projection_stage])
    # Walk the whole cursor without using the documents; add a print here to
    # inspect them when running the example by hand.
    for _doc in cursor:
        continue


if __name__ == "__main__":
    # Script entry point: run the sample query once and always close the
    # MongoDB client afterwards, whether or not the query succeeded.
    try:
        vector_query()
    except Exception as exc:
        # Top-level boundary of the example: report the failure and fall
        # through to the cleanup below.
        print("error", exc)
    finally:
        client.close()
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"EMBEDCHAIN": "Embedchain",
"AUTOGEN": "Autogen",
"XAI": "XAI",
"MONGODB": "MongoDB",
"AWS_BEDROCK": "AWS Bedrock",
}

Expand Down
8 changes: 8 additions & 0 deletions src/langtrace_python_sdk/constants/instrumentation/pymongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Registry of PyMongo calls to instrument, keyed by a short operation label.
# For each entry the PyMongo instrumentor passes "MODULE" and "METHOD" to
# wrapt's wrap_function_wrapper as the target module and dotted attribute path,
# and uses "SPAN_NAME" as the name given to the tracing wrapper.
# NOTE(review): "OPERATION" is not read by the instrumentor code visible in
# this change -- confirm whether it is consumed elsewhere.
APIS = {
"AGGREGATE": {
"MODULE": "pymongo.collection",
"METHOD": "Collection.aggregate",
"OPERATION": "aggregate",
"SPAN_NAME": "MongoDB Aggregate",
},
}
2 changes: 2 additions & 0 deletions src/langtrace_python_sdk/instrumentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .aws_bedrock import AWSBedrockInstrumentation
from .embedchain import EmbedchainInstrumentation
from .litellm import LiteLLMInstrumentation
from .pymongo import PyMongoInstrumentation

__all__ = [
"AnthropicInstrumentation",
Expand All @@ -45,5 +46,6 @@
"VertexAIInstrumentation",
"GeminiInstrumentation",
"MistralInstrumentation",
"PyMongoInstrumentation",
"AWSBedrockInstrumentation",
]
14 changes: 14 additions & 0 deletions src/langtrace_python_sdk/instrumentation/openai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
set_event_completion,
StreamWrapper,
set_span_attributes,
set_usage_attributes,
)
from langtrace_python_sdk.types import NOT_GIVEN

Expand Down Expand Up @@ -450,6 +451,14 @@ def traced_method(
span_attributes[SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS] = json.dumps(
[kwargs.get("input", "")]
)
span_attributes[SpanAttributes.LLM_PROMPTS] = json.dumps(
[
{
"role": "user",
"content": kwargs.get("input"),
}
]
)

attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))

Expand All @@ -463,6 +472,11 @@ def traced_method(
try:
# Attempt to call the original method
result = wrapped(*args, **kwargs)
usage = getattr(result, "usage", None)
if usage:
set_usage_attributes(
span, {"prompt_tokens": getattr(usage, "prompt_tokens", 0)}
)
span.set_status(StatusCode.OK)
return result
except Exception as err:
Expand Down
5 changes: 5 additions & 0 deletions src/langtrace_python_sdk/instrumentation/pymongo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .instrumentation import PyMongoInstrumentation

__all__ = [
"PyMongoInstrumentation",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Copyright (c) 2024 Scale3 Labs

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer

from typing import Collection
from importlib_metadata import version as v
from wrapt import wrap_function_wrapper as _W
from .patch import generic_patch
from langtrace_python_sdk.constants.instrumentation.pymongo import APIS


class PyMongoInstrumentation(BaseInstrumentor):
    """
    Instrumentor that wraps every PyMongo method listed in ``APIS`` with a
    tracing wrapper built by ``generic_patch``.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement this instrumentor declares."""
        return ["pymongo >= 4.0.0"]

    def _instrument(self, **kwargs):
        """Install one tracing wrapper per ``APIS`` entry.

        The tracer is obtained from the optional ``tracer_provider`` keyword
        argument, and the installed pymongo version is passed to each wrapper.
        """
        provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", provider)
        pymongo_version = v("pymongo")

        for spec in APIS.values():
            target_module = spec["MODULE"]
            target_method = spec["METHOD"]
            wrapper = generic_patch(spec["SPAN_NAME"], pymongo_version, tracer)
            _W(module=target_module, name=target_method, wrapper=wrapper)

    def _uninstrument(self, **kwargs):
        """Do nothing: wrappers installed by ``_instrument`` are left in place."""
        pass
73 changes: 73 additions & 0 deletions src/langtrace_python_sdk/instrumentation/pymongo/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from langtrace_python_sdk.utils.llm import (
get_langtrace_attributes,
get_span_name,
set_span_attributes,
set_span_attribute,
)
from langtrace_python_sdk.utils import deduce_args_and_kwargs
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.trace import SpanKind
from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
from langtrace.trace_attributes import DatabaseSpanAttributes

import json


class _BufferedCursor:
    """Replay documents that were already read from a single-pass cursor.

    ``traced_method`` has to iterate the aggregation cursor to record one span
    event per document, which leaves the real cursor exhausted. This wrapper
    hands the buffered documents back to the caller and forwards every other
    attribute (``close``, ``alive``, ...) to the underlying cursor.

    NOTE(review): cursor methods that read from the server-side cursor and are
    not overridden here (e.g. ``try_next``) still operate on the exhausted
    cursor, exactly as they did before this wrapper existed.
    """

    def __init__(self, cursor, documents):
        self._cursor = cursor
        self._pending = iter(documents)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._pending)

    # Cursors also expose ``next()`` as a plain method; keep that spelling working.
    next = __next__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._cursor.close()

    def __getattr__(self, attr):
        # Only reached for names not defined on this class. Guard the backing
        # attribute so a half-initialised instance cannot recurse forever.
        if attr == "_cursor":
            raise AttributeError(attr)
        return getattr(self._cursor, attr)


def generic_patch(name, version, tracer):
    """Build the tracing wrapper installed around a PyMongo collection method.

    Args:
        name: Span name for the wrapped call (passed through ``get_span_name``).
        version: Installed pymongo version, recorded in the span attributes.
        tracer: OpenTelemetry tracer used to start the span.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable suitable for
        ``wrapt.wrap_function_wrapper``.
    """

    def traced_method(wrapped, instance, args, kwargs):
        """Run the wrapped call inside a CLIENT span and record its results.

        Returns the wrapped call's result. When the span is recording, the
        result is iterated to emit one ``db.query.match`` event per document
        and an iterable wrapper that replays those documents is returned, so
        the caller still sees every document. Any exception from the wrapped
        call is recorded on the span and re-raised unchanged.
        """
        span_attributes = {
            **get_langtrace_attributes(
                version=version,
                service_provider=SERVICE_PROVIDERS["MONGODB"],
                vendor_type="vectordb",
            ),
            "db.system": "mongodb",
            "db.query": "aggregate",
        }

        attributes = DatabaseSpanAttributes(**span_attributes)

        with tracer.start_as_current_span(
            get_span_name(name), kind=SpanKind.CLIENT
        ) as span:
            recording = span.is_recording()
            if recording:
                set_input_attributes(
                    span, deduce_args_and_kwargs(wrapped, *args, **kwargs)
                )
                set_span_attributes(span, attributes)

            try:
                result = wrapped(*args, **kwargs)
                if not recording:
                    # Nothing to record: hand back the untouched cursor rather
                    # than draining and buffering it for no benefit.
                    return result

                # The cursor can only be walked once, so keep the documents
                # and give them back to the caller after recording the events.
                documents = list(result)
                for doc in documents:
                    span.add_event(
                        name="db.query.match",
                        attributes={**doc},
                    )
                return _BufferedCursor(result, documents)
            except Exception as err:
                # Record the exception in the span
                span.record_exception(err)

                # Set the span status to indicate an error
                span.set_status(Status(StatusCode.ERROR, str(err)))

                # Reraise the exception to ensure it's not swallowed
                raise

    return traced_method


def set_input_attributes(span, args):
    """Record the aggregation pipeline stages as attributes on ``span``.

    Args:
        span: Span that receives the attributes (via ``set_span_attribute``).
        args: Mapping of the wrapped call's arguments by name; the pipeline is
            read from its ``"pipeline"`` key.

    A ``$vectorSearch`` stage is expanded into ``db.index``, ``db.path``,
    ``db.top_k`` (from ``numCandidates``) and ``db.limit``. Every other stage
    is stored under its operator name as a JSON string. A missing or ``None``
    pipeline records nothing.
    """
    # Telemetry must never break the caller's query: treat an absent pipeline
    # as empty instead of iterating over None.
    pipeline = args.get("pipeline") or []
    for stage in pipeline:
        for k, v in stage.items():
            if k == "$vectorSearch":
                set_span_attribute(span, "db.index", v.get("index", None))
                set_span_attribute(span, "db.path", v.get("path", None))
                set_span_attribute(span, "db.top_k", v.get("numCandidates"))
                set_span_attribute(span, "db.limit", v.get("limit"))
            else:
                # Stages may hold values JSON cannot encode (ObjectId,
                # datetime, ...). Stringifying them is deliberate here: the
                # attribute is diagnostic only, and raising would abort the
                # user's aggregation.
                set_span_attribute(span, k, json.dumps(v, default=str))
5 changes: 4 additions & 1 deletion src/langtrace_python_sdk/langtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
AutogenInstrumentation,
VertexAIInstrumentation,
WeaviateInstrumentation,
PyMongoInstrumentation,
)
from opentelemetry.util.re import parse_env_headers

Expand All @@ -75,6 +76,7 @@
validate_instrumentations,
)
from langtrace_python_sdk.utils.langtrace_sampler import LangtraceSampler
from langtrace_python_sdk.extensions.langtrace_exporter import LangTraceExporter
from sentry_sdk.types import Event, Hint

logging.disable(level=logging.INFO)
Expand Down Expand Up @@ -150,7 +152,7 @@ def get_exporter(config: LangtraceConfig, host: str):
headers = get_headers(config)
host = f"{host}/api/trace" if host == LANGTRACE_REMOTE_URL else host
if "http" in host.lower() or "https" in host.lower():
return HTTPExporter(endpoint=host, headers=headers)
return LangTraceExporter(host, config.api_key, config.disable_logging)
else:
return GRPCExporter(endpoint=host, headers=headers)

Expand Down Expand Up @@ -280,6 +282,7 @@ def init(
"mistralai": MistralInstrumentation(),
"boto3": AWSBedrockInstrumentation(),
"autogen": AutogenInstrumentation(),
"pymongo": PyMongoInstrumentation(),
}

init_instrumentations(config.disable_instrumentations, all_instrumentations)
Expand Down
Loading