Commit 0a3478d

Merge pull request #1013 from parea-ai/PAI-1405-auto-trace-cohere
feat(cohere): auto trace cohere
2 parents 75b69c3 + 58e9595 commit 0a3478d

13 files changed, +2669 -1512 lines changed


cookbook/ab_testing.py

Lines changed: 24 additions & 19 deletions
@@ -1,15 +1,16 @@
+from typing import Tuple
+
 import os
 import random
-from typing import Tuple

 from openai import OpenAI

-from parea import trace, trace_insert, Parea, get_current_trace_id
+from parea import Parea, get_current_trace_id, trace, trace_insert
 from parea.schemas import FeedbackRequest

 client = OpenAI()
 # instantiate Parea client
-p = Parea(api_key=os.getenv('PAREA_API_KEY'))
+p = Parea(api_key=os.getenv("PAREA_API_KEY"))
 # wrap OpenAI client to trace calls
 p.wrap_openai_client(client)

@@ -19,28 +20,32 @@ def generate_email(user: str) -> Tuple[str, str]:
     """Randomly chooses a prompt to perform an A/B test for generating email. Returns the email and the trace ID.
     The latter is used to tie-back the collected feedback from the user."""
     if random.random() < 0.5:
-        trace_insert({'metadata': {'ab_test_0': 'variant_0'}})
-        prompt = f'Generate a long email for {user}'
+        trace_insert({"metadata": {"ab_test_0": "variant_0"}})
+        prompt = f"Generate a long email for {user}"
     else:
-        trace_insert({'metadata': {'ab_test_0': 'variant_1'}})
-        prompt = f'Generate a short email for {user}'
-
-    email = client.chat.completions.create(
-        model="gpt-4o",
-        messages=[
-            {
-                "role": "user",
-                "content": prompt,
-            }
-        ],
-    ).choices[0].message.content
+        trace_insert({"metadata": {"ab_test_0": "variant_1"}})
+        prompt = f"Generate a short email for {user}"
+
+    email = (
+        client.chat.completions.create(
+            model="gpt-4o",
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+        )
+        .choices[0]
+        .message.content
+    )

     return email, get_current_trace_id()


 def main():
     # generate email and get trace ID
-    email, trace_id = generate_email('Max Mustermann')
+    email, trace_id = generate_email("Max Mustermann")

     # log user feedback on email using trace ID
     p.record_feedback(
@@ -51,5 +56,5 @@ def main():
     )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

cookbook/cohere/trace_cohere.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import os

import cohere
from dotenv import load_dotenv

from parea import Parea

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)

response = co.chat(
    model="command-r-plus",
    preamble="You are a helpful assistant talking in JSON.",
    message="Generate a JSON describing a person, with the fields 'name' and 'age'",
    response_format={"type": "json_object"},
)
print(response)
print("\n\n")

response = co.chat(message="Who discovered gravity?")
print(response)
print("\n\n")

docs = [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
    "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
]
response = co.rerank(
    model="rerank-english-v3.0",
    query="What is the capital of the United States?",
    documents=docs,
    top_n=3,
)
print(response)
print("\n\n")


response = co.chat(
    model="command-r-plus",
    message="Where do the tallest penguins live?",
    documents=[
        {"title": "Tall penguins", "snippet": "Emperor penguins are the tallest."},
        {"title": "Penguin habitats", "snippet": "Emperor penguins only live in Antarctica."},
        {"title": "What are animals?", "snippet": "Animals are different from plants."},
    ],
)
print(response)
print("\n\n")

response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", search_queries_only=True)
print(response)
print("\n\n")

response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", connectors=[{"id": "web-search"}])
print(response)
print("\n\n")

for event in co.chat_stream(message="Who discovered gravity?"):
    print(event)
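
Note: the final loop above prints each raw stream event. A minimal follow-up sketch, not part of this commit, that collects only the generated text, assuming the streamed chat events expose an event_type attribute and a text attribute on "text-generation" events (as in recent versions of the Cohere Python SDK):

# Sketch (assumption): accumulate only the generated text from the stream.
chunks = []
for event in co.chat_stream(message="Who discovered gravity?"):
    if event.event_type == "text-generation":  # other event types include stream-start and stream-end
        chunks.append(event.text)
print("".join(chunks))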

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import os

import cohere
from dotenv import load_dotenv

from parea import Parea
from parea.utils.universal_encoder import json_dumps

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)


def web_search(query: str) -> list[dict]:
    # your code for performing a web search goes here
    return [{"url": "https://en.wikipedia.org/wiki/Ontario", "text": "The capital of Ontario is Toronto, ..."}]


web_search_tool = {
    "name": "web_search",
    "description": "performs a web search with the specified query",
    "parameter_definitions": {"query": {"description": "the query to look up", "type": "str", "required": True}},
}

message = "Who is the mayor of the capital of Ontario?"
model = "command-r-plus"

# STEP 2: Check what tools the model wants to use and how

res = co.chat(model=model, message=message, force_single_step=False, tools=[web_search_tool])

# as long as the model sends back tool_calls,
# keep invoking tools and sending the results back to the model
while res.tool_calls:
    print(res.text)  # This will be an observation and a plan with next steps
    tool_results = []
    for call in res.tool_calls:
        # use the `web_search` tool with the search query the model sent back
        web_search_results = {"call": call, "outputs": web_search(call.parameters["query"])}
        tool_results.append(web_search_results)

    # call chat again with tool results
    res = co.chat(model="command-r-plus", chat_history=res.chat_history, message="", force_single_step=False, tools=[web_search_tool], tool_results=tool_results)

print(res.text)  # "The mayor of Toronto, the capital of Ontario is Olivia Chow"


# tool descriptions that the model has access to
tools = [
    {
        "name": "query_daily_sales_report",
        "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
        "parameter_definitions": {"day": {"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", "type": "str", "required": True}},
    },
    {
        "name": "query_product_catalog",
        "description": "Connects to a product catalog with information about all the products being sold, including categories, prices, and stock levels.",
        "parameter_definitions": {"category": {"description": "Retrieves product information data for all products in this category.", "type": "str", "required": True}},
    },
]

# preamble containing instructions about the task and the desired style for the output.
preamble = """
## Task & Context
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.

## Style Guide
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
"""

# user request
message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?"

response = co.chat(message=message, force_single_step=True, tools=tools, preamble=preamble, model="command-r")
print("The model recommends doing the following tool calls:")
print("\n".join(str(tool_call) for tool_call in response.tool_calls))

tool_results = []
# Iterate over the tool calls generated by the model
for tool_call in response.tool_calls:
    # here is where you would call the tool recommended by the model, using the parameters recommended by the model
    output = {"output": f"functions_map[{tool_call.name}]({tool_call.parameters})"}
    # store the output in a list
    outputs = [output]
    # store your tool results in this format
    tool_results.append({"call": tool_call, "outputs": outputs})


print("Tool results that will be fed back to the model in step 4:")
print(json_dumps(tool_results, indent=4))

response = co.chat(message=message, tools=tools, tool_results=tool_results, preamble=preamble, model="command-r", temperature=0.3, force_single_step=True)


print("Final answer:")
print(response.text)
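
Note: in the single-step example above, each tool output is only a placeholder string built from the name "functions_map". A minimal sketch, not part of this commit, of how the loop over response.tool_calls could actually dispatch to local functions, assuming a hypothetical functions_map dict keyed by tool name and stubbed return values for illustration:

# Sketch (assumption): actually invoke the tools the model requested.
def query_daily_sales_report(day: str) -> dict:
    # stubbed data for illustration only
    return {"date": day, "summary": "Total sales amount: 10000, total units sold: 250"}


def query_product_catalog(category: str) -> dict:
    # stubbed data for illustration only
    return {"category": category, "products": [{"name": "Laptop", "price": 1200, "stock": 15}]}


functions_map = {
    "query_daily_sales_report": query_daily_sales_report,
    "query_product_catalog": query_product_catalog,
}

tool_results = []
for tool_call in response.tool_calls or []:
    # look up the local function by the tool name and call it with the model-provided parameters
    output = functions_map[tool_call.name](**tool_call.parameters)
    tool_results.append({"call": tool_call, "outputs": [output]})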

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
from typing import List, Optional

import os
from datetime import datetime

import cohere
from dotenv import load_dotenv

from parea import Parea, trace, trace_insert

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)


def call_llm(message: str, chat_history: Optional[List[dict]] = None, system_message: str = "", model: str = "command-r-plus") -> str:
    return co.chat(
        model=model,
        preamble=system_message,
        chat_history=chat_history or [],
        message=message,
    ).text


@trace
def argumentor(query: str, additional_description: str = "") -> str:
    return call_llm(
        system_message=f"""You are a debater making an argument on a topic. {additional_description}.
        The current time is {datetime.now().strftime("%Y-%m-%d")}""",
        message=f"The discussion topic is {query}",
    )


@trace
def critic(argument: str) -> str:
    return call_llm(
        system_message="""You are a critic.
        What unresolved questions or criticism do you have after reading the following argument?
        Provide a concise summary of your feedback.""",
        message=argument,
    )


@trace
def refiner(query: str, additional_description: str, argument: str, criticism: str) -> str:
    return call_llm(
        system_message=f"""You are a debater making an argument on a topic. {additional_description}.
        The current time is {datetime.now().strftime("%Y-%m-%d")}""",
        chat_history=[{"role": "USER", "message": f"""The discussion topic is {query}"""}, {"role": "CHATBOT", "message": argument}, {"role": "USER", "message": criticism}],
        message="Please generate a new argument that incorporates the feedback from the user.",
    )


@trace
def argument_chain(query: str, additional_description: str = "") -> str:
    trace_insert({"session_id": "cus_1234", "end_user_identifier": "user_1234"})
    argument = argumentor(query, additional_description)
    criticism = critic(argument)
    refined_argument = refiner(query, additional_description, argument, criticism)
    return refined_argument


@trace(session_id="cus_1234", end_user_identifier="user_1234")
def json_call() -> str:
    completion = co.chat(
        model="command-r-plus",
        preamble="You are a helpful assistant talking in JSON.",
        message="What are you?",
        response_format={"type": "json_object"},
    )
    return completion.text


if __name__ == "__main__":
    result = argument_chain(
        "Whether sparkling wine is good for you.",
        additional_description="Provide a concise, few sentence argument on why sparkling wine is good for you.",
    )
    print(result)
    print(json_call())

parea/client.py

Lines changed: 7 additions & 0 deletions
@@ -118,6 +118,13 @@ def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str]
         if integration:
             self._add_integration(integration)

+    def wrap_cohere_client(self, client: Union["cohere.Client", "cohere.AsyncClient"], integration: Optional[str] = None) -> None:
+        from parea.wrapper.cohere.wrap_cohere import CohereClientWrapper
+
+        CohereClientWrapper().init(client=client)
+        if integration:
+            self._add_integration(integration)
+
     def _add_integration(self, integration: str) -> None:
         self._client.add_integration(integration)
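
Note: the new wrap_cohere_client method mirrors the existing wrap_openai_client and wrap_anthropic_client entry points. A minimal usage sketch based on the cookbook files in this commit; the async form is inferred from the Union["cohere.Client", "cohere.AsyncClient"] type hint and is an assumption, not something the cookbooks demonstrate:

import os

import cohere

from parea import Parea

p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# Sync client, as used in the cookbook files above: chat, rerank, and chat_stream
# calls made through this client are traced automatically after wrapping.
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)

# Async client -- accepted per the method's type hint; wrapping it this way is an assumption.
aco = cohere.AsyncClient(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(aco)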
