Skip to content

Commit 888edf4

Browse files
committed
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into development
2 parents 85b15e4 + 9c4b2a9 commit 888edf4

File tree

17 files changed

+636
-6
lines changed

17 files changed

+636
-6
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
class CerebrasRunner:
    """Runs all Cerebras SDK examples in sequence."""

    def run(self):
        # Imported lazily so this module can be imported without the
        # example's dependencies being installed.
        from .main import (
            completion_example,
            completion_with_tools_example,
            openai_cerebras_example,
        )

        completion_with_tools_example()
        completion_example()
        openai_cerebras_example()
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from langtrace_python_sdk import langtrace
2+
from cerebras.cloud.sdk import Cerebras
3+
from dotenv import load_dotenv
4+
import re
5+
import json
6+
from openai import OpenAI
7+
import os
8+
9+
# Load credentials (e.g. CEREBRAS_API_KEY) from a local .env file.
load_dotenv()

langtrace.init()

# OpenAI-compatible client pointed at the Cerebras endpoint.
openai_client = OpenAI(
    base_url="https://api.cerebras.ai/v1",
    api_key=os.getenv("CEREBRAS_API_KEY"),
)
# Native Cerebras SDK client — presumably picks up its API key from the
# environment; TODO confirm against the SDK docs.
client = Cerebras()
17+
18+
19+
def openai_cerebras_example(stream=False):
    """Run a chat completion against Cerebras through the OpenAI client.

    When *stream* is true, prints each chunk as it arrives and returns
    None; otherwise returns the full completion object.
    """
    prompt = {"role": "user", "content": "Why is fast inference important?"}
    completion = openai_client.chat.completions.create(
        messages=[prompt],
        model="llama3.1-8b",
        stream=stream,
    )

    if not stream:
        return completion
    for chunk in completion:
        print(chunk)
36+
37+
38+
def completion_example(stream=False):
    """Run a chat completion with the native Cerebras SDK client.

    Streams chunks to stdout when *stream* is true; otherwise returns the
    completion response object.
    """
    response = client.chat.completions.create(
        messages=[
            {"role": "user", "content": "Why is fast inference important?"},
        ],
        model="llama3.1-8b",
        stream=stream,
    )

    if stream:
        for part in response:
            print(part)
        return None
    return response
55+
56+
57+
def completion_with_tools_example(stream=False):
    """Demonstrate Cerebras tool calling with a local calculator tool.

    Asks the model a math question, executes the requested ``calculate``
    tool locally, appends the tool result to the conversation, and prints
    the model's final answer. Handles both streaming and non-streaming
    responses.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant with access to a calculator. Use the calculator tool to compute mathematical expressions when needed.",
        },
        {"role": "user", "content": "What's the result of 15 multiplied by 7?"},
    ]

    response = client.chat.completions.create(
        model="llama3.1-8b",
        messages=messages,
        tools=tools,
        stream=stream,
    )

    if stream:
        # Streaming: inspect each delta for either a tool call or text.
        full_content = ""
        for chunk in response:
            if chunk.choices[0].delta.tool_calls:
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if (
                    hasattr(tool_call, "function")
                    and tool_call.function.name == "calculate"
                ):
                    # FIX: streamed tool-call arguments can arrive split
                    # across chunks (or be None); skip deltas that are not
                    # yet complete JSON instead of crashing mid-stream.
                    try:
                        arguments = json.loads(tool_call.function.arguments)
                    except (TypeError, json.JSONDecodeError):
                        continue
                    result = calculate(arguments["expression"])
                    print(f"Calculation result: {result}")

                    # Echo the tool call and its result back to the model
                    # to get the final natural-language response.
                    messages.append(
                        {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [
                                {
                                    "function": {
                                        "name": "calculate",
                                        "arguments": tool_call.function.arguments,
                                    },
                                    "id": tool_call.id,
                                    "type": "function",
                                }
                            ],
                        }
                    )
                    messages.append(
                        {
                            "role": "tool",
                            "content": str(result),
                            "tool_call_id": tool_call.id,
                        }
                    )

                    final_response = client.chat.completions.create(
                        model="llama3.1-70b", messages=messages, stream=True
                    )

                    for final_chunk in final_response:
                        if final_chunk.choices[0].delta.content:
                            print(final_chunk.choices[0].delta.content, end="")
            elif chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="")
                full_content += chunk.choices[0].delta.content
    else:
        # Non-streaming: a single message either carries tool calls or text.
        choice = response.choices[0].message
        if choice.tool_calls:
            function_call = choice.tool_calls[0].function
            if function_call.name == "calculate":
                arguments = json.loads(function_call.arguments)
                result = calculate(arguments["expression"])
                print(f"Calculation result: {result}")

                messages.append(
                    {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "function": {
                                    "name": "calculate",
                                    "arguments": function_call.arguments,
                                },
                                "id": choice.tool_calls[0].id,
                                "type": "function",
                            }
                        ],
                    }
                )
                messages.append(
                    {
                        "role": "tool",
                        "content": str(result),
                        "tool_call_id": choice.tool_calls[0].id,
                    }
                )

                final_response = client.chat.completions.create(
                    model="llama3.1-70b",
                    messages=messages,
                )

                if final_response:
                    print(final_response.choices[0].message.content)
                else:
                    print("No final response received")
        else:
            print("Unexpected response from the model")
165+
166+
167+
def calculate(expression):
    """Safely evaluate a basic arithmetic expression string.

    The input is first stripped of every character that is not a digit,
    arithmetic operator, dot, or parenthesis, then evaluated with a small
    AST walker instead of ``eval`` — the expression comes from model
    output, i.e. untrusted input.

    Returns the result as a string, or "Error: Invalid expression" when
    the expression is malformed, divides by zero, or overflows.
    """
    import ast
    import operator

    # Whitelist: digits, + - * / ( ) and the decimal point.
    expression = re.sub(r"[^0-9+\-*/().]", "", expression)

    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    def _eval(node):
        # Only numeric literals and whitelisted operators are accepted.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.operand))
        raise ValueError("unsupported expression")

    try:
        result = _eval(ast.parse(expression, mode="eval"))
        return str(result)
    except (SyntaxError, ValueError, ZeroDivisionError, OverflowError):
        return "Error: Invalid expression"
175+
176+
177+
# Tool schema advertised to the model: a single calculator function that
# takes one string parameter, "expression".
tools = [
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "A calculator tool that can perform basic arithmetic operations. Use this when you need to compute mathematical expressions or solve numerical problems.",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "The mathematical expression to evaluate",
                    }
                },
                "required": ["expression"],
            },
        },
    }
]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from langtrace_python_sdk import langtrace, with_langtrace_root_span
2+
import pymongo
3+
import os
4+
from dotenv import load_dotenv
5+
from openai import OpenAI
6+
7+
# Pull MONGO_URI / OPENAI_API_KEY from a local .env file.
load_dotenv()
langtrace.init(write_spans_to_console=False, batch=False)

# Embedding model used for the vector-search query below.
MODEL = "text-embedding-ada-002"

openai_client = OpenAI()
client = pymongo.MongoClient(os.environ["MONGO_URI"])
12+
13+
14+
# Define a function to generate embeddings
def get_embedding(text):
    """Generates vector embeddings for the given text."""
    response = openai_client.embeddings.create(input=[text], model=MODEL)
    return response.data[0].embedding
21+
22+
23+
@with_langtrace_root_span("mongo-vector-search")
def vector_query():
    """Run a $vectorSearch aggregation over the sample_mflix embeddings."""
    db = client["sample_mflix"]
    embedded_movies_collection = db["embedded_movies"]

    # Stage 1: approximate nearest-neighbour search on the plot embeddings.
    search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "path": "plot_embedding",
            "queryVector": get_embedding("time travel"),
            "numCandidates": 150,
            "limit": 10,
        }
    }
    # Stage 2: keep only the fields of interest plus the similarity score.
    project_stage = {
        "$project": {
            "_id": 0,
            "plot": 1,
            "title": 1,
            "score": {"$meta": "vectorSearchScore"},
        }
    }

    # Drain the cursor so the query actually executes (results are unused).
    for _doc in embedded_movies_collection.aggregate([search_stage, project_stage]):
        pass
53+
54+
55+
if __name__ == "__main__":
    try:
        vector_query()
    except Exception as exc:
        # Surface failures without a traceback; this is a demo script.
        print("error", exc)
    finally:
        client.close()

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
"EMBEDCHAIN": "Embedchain",
3535
"AUTOGEN": "Autogen",
3636
"XAI": "XAI",
37+
"MONGODB": "MongoDB",
3738
"AWS_BEDROCK": "AWS Bedrock",
39+
"CEREBRAS": "Cerebras",
3840
}
3941

4042
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Registry of instrumented PyMongo operations: for each operation, the
# module/method to patch and the span name to emit.
APIS = {
    "AGGREGATE": {
        "MODULE": "pymongo.collection",
        "METHOD": "Collection.aggregate",
        "OPERATION": "aggregate",
        "SPAN_NAME": "MongoDB Aggregate",
    },
}

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from .aws_bedrock import AWSBedrockInstrumentation
2222
from .embedchain import EmbedchainInstrumentation
2323
from .litellm import LiteLLMInstrumentation
24+
from .pymongo import PyMongoInstrumentation
25+
from .cerebras import CerebrasInstrumentation
2426

2527
__all__ = [
2628
"AnthropicInstrumentation",
@@ -45,5 +47,7 @@
4547
"VertexAIInstrumentation",
4648
"GeminiInstrumentation",
4749
"MistralInstrumentation",
50+
"PyMongoInstrumentation",
4851
"AWSBedrockInstrumentation",
52+
"CerebrasInstrumentation",
4953
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Re-export the instrumentor as this package's sole public name.
from .instrumentation import CerebrasInstrumentation

__all__ = ["CerebrasInstrumentation"]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Collection
18+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
19+
from opentelemetry.trace import get_tracer
20+
from opentelemetry.semconv.schemas import Schemas
21+
from wrapt import wrap_function_wrapper
22+
from importlib_metadata import version as v
23+
from .patch import chat_completions_create, async_chat_completions_create
24+
25+
26+
class CerebrasInstrumentation(BaseInstrumentor):
    """OpenTelemetry instrumentor for the Cerebras Cloud SDK.

    Wraps the synchronous and asynchronous chat-completion entry points of
    ``cerebras.cloud.sdk`` so each call emits a trace span.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Packages that must be installed for this instrumentation."""
        return ["cerebras-cloud-sdk >= 1.0.0"]

    def _instrument(self, **kwargs):
        """Patch the SDK's completion-create methods with tracing wrappers."""
        tracer = get_tracer(
            __name__,
            "",
            kwargs.get("tracer_provider"),
            schema_url=Schemas.V1_27_0.value,
        )
        sdk_version = v("cerebras-cloud-sdk")

        # Sync first, then async — same order as the original wiring.
        for resource, wrapper_factory in (
            ("CompletionsResource", chat_completions_create),
            ("AsyncCompletionsResource", async_chat_completions_create),
        ):
            wrap_function_wrapper(
                module="cerebras.cloud.sdk",
                name=f"resources.chat.completions.{resource}.create",
                wrapper=wrapper_factory(sdk_version, tracer),
            )

    def _uninstrument(self, **kwargs):
        # No-op: patches are not removed.
        pass

0 commit comments

Comments
 (0)