Skip to content

Commit e6d8542

Browse files
authored
chore: add back openai tool choice arg (#245)
* chore: add back openai tool choice arg * style: fix formatting
1 parent a9d3400 commit e6d8542

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
# Example taken from https://platform.openai.com/docs/guides/function-calling
import json

from dotenv import find_dotenv, load_dotenv
from openai import OpenAI

from langtrace_python_sdk import langtrace

# Load environment variables (e.g. OPENAI_API_KEY) BEFORE constructing the
# client: OpenAI() reads its credentials from the environment at creation
# time, so a key that lives only in `.env` would be missed otherwise.
_ = load_dotenv(find_dotenv())

client = OpenAI()

langtrace.init(
    write_spans_to_console=True,
)
# Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API
def get_current_weather(location, unit="fahrenheit"):
    """Get the current weather in a given location"""
    # Ordered demo data: (substring to match, canonical city name, temp).
    # Order matters — the first substring found in the location wins.
    known_cities = (
        ("tokyo", "Tokyo", "10"),
        ("san francisco", "San Francisco", "72"),
        ("paris", "Paris", "22"),
    )
    needle = location.lower()
    for key, city, temp in known_cities:
        if key in needle:
            return json.dumps(
                {"location": city, "temperature": temp, "unit": unit}
            )
    # Unknown city: echo the raw location; note the original omits "unit" here.
    return json.dumps({"location": location, "temperature": "unknown"})
def run_conversation():
    """Ask the model a weather question with the get_current_weather tool.

    Sends a single user message plus the tool schema, forcing a tool call via
    ``tool_choice="required"``. Prints the tool calls the model requested and
    returns them so the caller's ``print(run_conversation())`` does not just
    print ``None``.
    """
    # Step 1: send the conversation and available functions to the model
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
        }
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        tools=tools,
        tool_choice="required",  # auto is default, but we'll be explicit
    )
    response_message = response.choices[0].message
    tool_calls = response_message.tool_calls
    print(tool_calls)
    # Return the tool calls (may be None if the model answered in plain text)
    # so the module-level print of this function's result shows them.
    return tool_calls
# Guard the entry point so importing this module does not trigger a network
# call to the OpenAI API.
if __name__ == "__main__":
    print(run_conversation())

src/langtrace_python_sdk/instrumentation/anthropic/patch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def traced_method(wrapped, instance, args, kwargs):
4848
prompts = kwargs.get("messages", [])
4949
system = kwargs.get("system")
5050
if system:
51-
prompts = [{"role": "system", "content": system}] + kwargs.get("messages", [])
51+
prompts = [{"role": "system", "content": system}] + kwargs.get(
52+
"messages", []
53+
)
5254

5355
span_attributes = {
5456
**get_langtrace_attributes(version, service_provider),

src/langtrace_python_sdk/instrumentation/gemini/patch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ def get_llm_model(instance):
110110

111111
def serialize_prompts(args, kwargs, instance):
112112
prompts = []
113-
if hasattr(instance, "_system_instruction") and instance._system_instruction is not None:
113+
if (
114+
hasattr(instance, "_system_instruction")
115+
and instance._system_instruction is not None
116+
):
114117
system_prompt = {
115118
"role": "system",
116119
"content": instance._system_instruction.__dict__["_pb"].parts[0].text,

src/langtrace_python_sdk/utils/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None):
123123
SpanAttributes.LLM_FREQUENCY_PENALTY: kwargs.get("frequency_penalty"),
124124
SpanAttributes.LLM_REQUEST_SEED: kwargs.get("seed"),
125125
SpanAttributes.LLM_TOOLS: json.dumps(tools) if tools else None,
126+
SpanAttributes.LLM_TOOL_CHOICE: kwargs.get("tool_choice"),
126127
SpanAttributes.LLM_REQUEST_LOGPROPS: kwargs.get("logprobs"),
127128
SpanAttributes.LLM_REQUEST_LOGITBIAS: kwargs.get("logit_bias"),
128129
SpanAttributes.LLM_REQUEST_TOP_LOGPROPS: kwargs.get("top_logprobs"),

0 commit comments

Comments
 (0)