
Commit f77f5e8: Cerebras support
1 parent: e67d055
File tree: 10 files changed, +318 -21 lines
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+class CerebrasRunner:
+    def run(self):
+        from .main import completion_example, completion_with_tools_example
+
+        completion_with_tools_example()
+        completion_example()
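
The runner defers the example imports to run() time, so the cerebras-cloud-sdk dependency is only needed when the example actually executes. A minimal sketch of driving it, assuming langtrace.init() has been called elsewhere (the import in main.py suggests it, but the call is not part of this diff):

# Hypothetical driver; only the runner class is added by this commit.
from langtrace_python_sdk import langtrace

langtrace.init()  # assumes the usual langtrace setup (API key or local exporter) is configured

runner = CerebrasRunner()
runner.run()  # runs completion_with_tools_example() then completion_example()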
Lines changed: 151 additions & 2 deletions
@@ -1,6 +1,8 @@
 from langtrace_python_sdk import langtrace
 from cerebras.cloud.sdk import Cerebras
 from dotenv import load_dotenv
+import re
+import json

 load_dotenv()

@@ -9,7 +11,7 @@
 client = Cerebras()


-def completion_example():
+def completion_example(stream=False):
     completion = client.chat.completions.create(
         messages=[
             {
@@ -18,5 +20,152 @@ def completion_example():
             }
         ],
         model="llama3.1-8b",
+        stream=stream,
     )
-    return completion
+
+    if stream:
+        for chunk in completion:
+            print(chunk)
+    else:
+        return completion
+
+
+def completion_with_tools_example(stream=False):
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant with access to a calculator. Use the calculator tool to compute mathematical expressions when needed.",
+        },
+        {"role": "user", "content": "What's the result of 15 multiplied by 7?"},
+    ]
+
+    response = client.chat.completions.create(
+        model="llama3.1-8b",
+        messages=messages,
+        tools=tools,
+        stream=stream,
+    )
+
+    if stream:
+        # Handle streaming response
+        full_content = ""
+        for chunk in response:
+            if chunk.choices[0].delta.tool_calls:
+                tool_call = chunk.choices[0].delta.tool_calls[0]
+                if hasattr(tool_call, "function"):
+                    if tool_call.function.name == "calculate":
+                        arguments = json.loads(tool_call.function.arguments)
+                        result = calculate(arguments["expression"])
+                        print(f"Calculation result: {result}")
+
+                        # Get final response with calculation result
+                        messages.append(
+                            {
+                                "role": "assistant",
+                                "content": None,
+                                "tool_calls": [
+                                    {
+                                        "function": {
+                                            "name": "calculate",
+                                            "arguments": tool_call.function.arguments,
+                                        },
+                                        "id": tool_call.id,
+                                        "type": "function",
+                                    }
+                                ],
+                            }
+                        )
+                        messages.append(
+                            {
+                                "role": "tool",
+                                "content": str(result),
+                                "tool_call_id": tool_call.id,
+                            }
+                        )
+
+                        final_response = client.chat.completions.create(
+                            model="llama3.1-70b", messages=messages, stream=True
+                        )
+
+                        for final_chunk in final_response:
+                            if final_chunk.choices[0].delta.content:
+                                print(final_chunk.choices[0].delta.content, end="")
+            elif chunk.choices[0].delta.content:
+                print(chunk.choices[0].delta.content, end="")
+                full_content += chunk.choices[0].delta.content
+    else:
+        # Handle non-streaming response
+        choice = response.choices[0].message
+        if choice.tool_calls:
+            function_call = choice.tool_calls[0].function
+            if function_call.name == "calculate":
+                arguments = json.loads(function_call.arguments)
+                result = calculate(arguments["expression"])
+                print(f"Calculation result: {result}")
+
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "function": {
+                                    "name": "calculate",
+                                    "arguments": function_call.arguments,
+                                },
+                                "id": choice.tool_calls[0].id,
+                                "type": "function",
+                            }
+                        ],
+                    }
+                )
+                messages.append(
+                    {
+                        "role": "tool",
+                        "content": str(result),
+                        "tool_call_id": choice.tool_calls[0].id,
+                    }
+                )
+
+                final_response = client.chat.completions.create(
+                    model="llama3.1-70b",
+                    messages=messages,
+                )
+
+                if final_response:
+                    print(final_response.choices[0].message.content)
+                else:
+                    print("No final response received")
+        else:
+            print("Unexpected response from the model")
+
+
+def calculate(expression):
+    expression = re.sub(r"[^0-9+\-*/().]", "", expression)
+
+    try:
+        result = eval(expression)
+        return str(result)
+    except (SyntaxError, ZeroDivisionError, NameError, TypeError, OverflowError):
+        return "Error: Invalid expression"
+
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "calculate",
+            "description": "A calculator tool that can perform basic arithmetic operations. Use this when you need to compute mathematical expressions or solve numerical problems.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "expression": {
+                        "type": "string",
+                        "description": "The mathematical expression to evaluate",
+                    }
+                },
+                "required": ["expression"],
+            },
+        },
+    }
+]
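
A note on running this example: Cerebras() with no arguments relies on the API key being available in the environment (load_dotenv() above is there to populate it from a .env file), and the stream flag switches both functions between the streaming and non-streaming code paths. A small driver sketch, assuming the key is already set:

# Hypothetical entry point; not part of the committed example.
if __name__ == "__main__":
    completion_example()                        # non-streaming: returns the completion object
    completion_with_tools_example(stream=True)  # exercises the streaming tool-call branch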

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
     "AUTOGEN": "Autogen",
     "XAI": "XAI",
     "AWS_BEDROCK": "AWS Bedrock",
+    "CEREBRAS": "Cerebras",
 }

 LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from .aws_bedrock import AWSBedrockInstrumentation
 from .embedchain import EmbedchainInstrumentation
 from .litellm import LiteLLMInstrumentation
+from .cerebras import CerebrasInstrumentation

 __all__ = [
     "AnthropicInstrumentation",
@@ -46,4 +47,5 @@
     "GeminiInstrumentation",
     "MistralInstrumentation",
     "AWSBedrockInstrumentation",
+    "CerebrasInstrumentation",
 ]
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .instrumentation import CerebrasInstrumentation
+
+__all__ = ["CerebrasInstrumentation"]

src/langtrace_python_sdk/instrumentation/cerebras/instrumentation.py

Lines changed: 7 additions & 1 deletion
@@ -17,6 +17,7 @@
 from typing import Collection
 from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
 from opentelemetry.trace import get_tracer
+from opentelemetry.semconv.schemas import Schemas
 from wrapt import wrap_function_wrapper
 from importlib_metadata import version as v
 from .patch import chat_completions_create, async_chat_completions_create
@@ -32,7 +33,9 @@ def instrumentation_dependencies(self) -> Collection[str]:

     def _instrument(self, **kwargs):
         tracer_provider = kwargs.get("tracer_provider")
-        tracer = get_tracer(__name__, "", tracer_provider)
+        tracer = get_tracer(
+            __name__, "", tracer_provider, schema_url=Schemas.V1_27_0.value
+        )
         version = v("cerebras-cloud-sdk")

         wrap_function_wrapper(
@@ -46,3 +49,6 @@ def _instrument(self, **kwargs):
             name="resources.chat.completions.AsyncCompletionsResource.create",
             wrapper=async_chat_completions_create(version, tracer),
         )
+
+    def _uninstrument(self, **kwargs):
+        pass
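
CerebrasInstrumentation follows the standard OpenTelemetry BaseInstrumentor contract, so it is normally activated by langtrace's init path along with the other instrumentations registered in __init__.py. A manual-activation sketch (this usage is an assumption; the commit only adds the instrumentor itself):

# Hypothetical manual use of the instrumentor added above.
from langtrace_python_sdk.instrumentation import CerebrasInstrumentation

instrumentor = CerebrasInstrumentation()
instrumentor.instrument()    # wraps the Cerebras SDK chat-completion create methods
# ... make Cerebras chat completion calls; each one is traced via the patch module ...
instrumentor.uninstrument()  # currently a no-op, since _uninstrument() just passes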
Lines changed: 126 additions & 17 deletions
@@ -1,29 +1,138 @@
-"""
-Copyright (c) 2024 Scale3 Labs
+from langtrace_python_sdk.instrumentation.groq.patch import extract_content
+from opentelemetry.trace import SpanKind
+from langtrace_python_sdk.utils.llm import (
+    get_llm_request_attributes,
+    get_langtrace_attributes,
+    get_extra_attributes,
+    get_llm_url,
+    is_streaming,
+    set_event_completion,
+    set_span_attributes,
+    StreamWrapper,
+)
+from langtrace_python_sdk.utils.silently_fail import silently_fail
+from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
+from langtrace.trace_attributes import SpanAttributes
+from langtrace_python_sdk.utils import handle_span_error, set_span_attribute

-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at

-    http://www.apache.org/licenses/LICENSE-2.0
+def chat_completions_create(version: str, tracer):
+    def traced_method(wrapped, instance, args, kwargs):
+        llm_prompts = []
+        for message in kwargs.get("messages", []):
+            llm_prompts.append(message)

-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+        span_attributes = {
+            **get_langtrace_attributes(version, SERVICE_PROVIDERS["CEREBRAS"]),
+            **get_llm_request_attributes(kwargs, prompts=llm_prompts),
+            **get_llm_url(instance),
+            **get_extra_attributes(),
+        }

+        span_name = f"{span_attributes[SpanAttributes.LLM_OPERATION_NAME]} {span_attributes[SpanAttributes.LLM_REQUEST_MODEL]}"
+        with tracer.start_as_current_span(
+            name=span_name,
+            kind=SpanKind.CLIENT,
+            attributes=span_attributes,
+            end_on_exit=False,
+        ) as span:

-def chat_completions_create(version: str, tracer):
-    def traced_method(wrapped, instance, args, kwargs):
-        return wrapped(*args, **kwargs)
+            try:
+                _set_input_attributes(span, kwargs, span_attributes)
+                result = wrapped(*args, **kwargs)
+                if is_streaming(kwargs):
+                    return StreamWrapper(result, span)
+
+                if span.is_recording():
+                    _set_response_attributes(span, result)
+                span.end()
+                return result
+
+            except Exception as error:
+                handle_span_error(span, error)
+                raise

     return traced_method


 def async_chat_completions_create(version: str, tracer):
-    def traced_method(wrapped, instance, args, kwargs):
-        return wrapped(*args, **kwargs)
+    async def traced_method(wrapped, instance, args, kwargs):
+        llm_prompts = []
+        for message in kwargs.get("messages", []):
+            llm_prompts.append(message)
+
+        span_attributes = {
+            **get_langtrace_attributes(version, SERVICE_PROVIDERS["CEREBRAS"]),
+            **get_llm_request_attributes(kwargs, prompts=llm_prompts),
+            **get_llm_url(instance),
+            **get_extra_attributes(),
+        }
+
+        span_name = f"{span_attributes[SpanAttributes.LLM_OPERATION_NAME]} {span_attributes[SpanAttributes.LLM_REQUEST_MODEL]}"
+        with tracer.start_as_current_span(
+            name=span_name,
+            kind=SpanKind.CLIENT,
+            attributes=span_attributes,
+            end_on_exit=False,
+        ) as span:
+
+            try:
+                _set_input_attributes(span, kwargs, span_attributes)
+                result = await wrapped(*args, **kwargs)
+                if is_streaming(kwargs):
+                    return StreamWrapper(result, span)
+
+                if span.is_recording():
+                    _set_response_attributes(span, result)
+                span.end()
+                return result
+
+            except Exception as error:
+                handle_span_error(span, error)
+                raise

     return traced_method
+
+
+@silently_fail
+def _set_response_attributes(span, result):
+    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, result.model)
+
+    if getattr(result, "id", None):
+        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_ID, result.id)
+
+    if getattr(result, "choices", None):
+        responses = [
+            {
+                "role": (
+                    choice.message.role
+                    if choice.message and choice.message.role
+                    else "assistant"
+                ),
+                "content": extract_content(choice),
+                **(
+                    {"content_filter_results": choice.content_filter_results}
+                    if hasattr(choice, "content_filter_results")
+                    else {}
+                ),
+            }
+            for choice in result.choices
+        ]
+        set_event_completion(span, responses)
+    # Get the usage
+    if getattr(result, "usage", None):
+        set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+            result.usage.prompt_tokens,
+        )
+        set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+            result.usage.completion_tokens,
+        )
+
+
+@silently_fail
+def _set_input_attributes(span, kwargs, attributes):
+    set_span_attributes(span, attributes)
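
Both patch functions use the wrapt convention: the callable installed by wrap_function_wrapper receives (wrapped, instance, args, kwargs) and decides when to call through to the original method, which is where the span is opened and closed around the Cerebras SDK call. A toy sketch of that mechanism, independent of any langtrace internals (Greeter and greet are made-up names):

# Illustrative wrapt example only.
from wrapt import wrap_function_wrapper


class Greeter:
    def greet(self, name):
        return f"hello {name}"


def wrapper(wrapped, instance, args, kwargs):
    # pre-call hook: the real patch starts a span and sets request attributes here
    print(f"calling {wrapped.__name__} on {type(instance).__name__}")
    result = wrapped(*args, **kwargs)  # call through to the original method
    # post-call hook: the real patch sets response attributes and ends the span here
    return result


wrap_function_wrapper(__name__, "Greeter.greet", wrapper)
print(Greeter().greet("world"))  # prints the hook line, then "hello world"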
