Skip to content

Commit 2746240

Browse files
committed
feat(cohere): auto trace cohere
1 parent 5ed5c09 commit 2746240

File tree

5 files changed

+236
-97
lines changed

5 files changed

+236
-97
lines changed

cookbook/cohere/trace_cohere.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,26 @@
4040
print(response)
4141
print("\n\n")
4242

43+
44+
response = co.chat(
45+
model="command-r-plus",
46+
message="Where do the tallest penguins live?",
47+
documents=[
48+
{"title": "Tall penguins", "snippet": "Emperor penguins are the tallest."},
49+
{"title": "Penguin habitats", "snippet": "Emperor penguins only live in Antarctica."},
50+
{"title": "What are animals?", "snippet": "Animals are different from plants."},
51+
],
52+
)
53+
print(response)
54+
print("\n\n")
55+
56+
response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", search_queries_only=True)
57+
print(response)
58+
print("\n\n")
59+
60+
response = co.chat(model="command-r-plus", message="Who is more popular: Nsync or Backstreet Boys?", connectors=[{"id": "web-search"}])
61+
print(response)
62+
print("\n\n")
63+
4364
for event in co.chat_stream(message="Who discovered gravity?"):
4465
print(event)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
3+
import cohere
4+
from dotenv import load_dotenv
5+
6+
from parea import Parea
7+
from parea.utils.universal_encoder import json_dumps
8+
9+
load_dotenv()
10+
11+
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
12+
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
13+
p.wrap_cohere_client(co)
14+
15+
16+
def web_search(query: str) -> list[dict]:
    """Stub tool implementation: return canned web-search results for *query*.

    Replace the body with a real search call. The model only ever sees the
    returned list of ``{"url": ..., "text": ...}`` dicts, which are fed back
    as tool outputs.
    """
    # your code for performing a web search goes here
    return [{"url": "https://en.wikipedia.org/wiki/Ontario", "text": "The capital of Ontario is Toronto, ..."}]
19+
20+
21+
web_search_tool = {
22+
"name": "web_search",
23+
"description": "performs a web search with the specified query",
24+
"parameter_definitions": {"query": {"description": "the query to look up", "type": "str", "required": True}},
25+
}
26+
27+
message = "Who is the mayor of the capital of Ontario?"
28+
model = "command-r-plus"
29+
30+
# STEP 2: Check what tools the model wants to use and how
31+
32+
res = co.chat(model=model, message=message, force_single_step=False, tools=[web_search_tool])
33+
34+
# as long as the model sends back tool_calls,
35+
# keep invoking tools and sending the results back to the model
36+
while res.tool_calls:
37+
print(res.text) # This will be an observation and a plan with next steps
38+
tool_results = []
39+
for call in res.tool_calls:
40+
# use the `web_search` tool with the search query the model sent back
41+
web_search_results = {"call": call, "outputs": web_search(call.parameters["query"])}
42+
tool_results.append(web_search_results)
43+
44+
# call chat again with tool results
45+
res = co.chat(model="command-r-plus", chat_history=res.chat_history, message="", force_single_step=False, tools=[web_search_tool], tool_results=tool_results)
46+
47+
print(res.text) # "The mayor of Toronto, the capital of Ontario is Olivia Chow"
48+
49+
50+
# tool descriptions that the model has access to
tools = [
    {
        "name": "query_daily_sales_report",
        "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
        "parameter_definitions": {"day": {"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", "type": "str", "required": True}},
    },
    {
        "name": "query_product_catalog",
        # fixed doubled article: "a a product catalog" -> "a product catalog"
        "description": "Connects to a product catalog with information about all the products being sold, including categories, prices, and stock levels.",
        "parameter_definitions": {"category": {"description": "Retrieves product information data for all products in this category.", "type": "str", "required": True}},
    },
]
63+
64+
# preamble containing instructions about the task and the desired style for the output.
65+
preamble = """
66+
## Task & Context
67+
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.
68+
69+
## Style Guide
70+
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
71+
"""
72+
73+
# user request
74+
message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?"
75+
76+
response = co.chat(message=message, force_single_step=True, tools=tools, preamble=preamble, model="command-r")
77+
print("The model recommends doing the following tool calls:")
78+
print("\n".join(str(tool_call) for tool_call in response.tool_calls))
79+
80+
tool_results = []
81+
# Iterate over the tool calls generated by the model
82+
for tool_call in response.tool_calls:
83+
# here is where you would call the tool recommended by the model, using the parameters recommended by the model
84+
output = {"output": f"functions_map[{tool_call.name}]({tool_call.parameters})"}
85+
# store the output in a list
86+
outputs = [output]
87+
# store your tool results in this format
88+
tool_results.append({"call": tool_call, "outputs": outputs})
89+
90+
91+
print("Tool results that will be fed back to the model in step 4:")
92+
print(json_dumps(tool_results, indent=4))
93+
94+
response = co.chat(message=message, tools=tools, tool_results=tool_results, preamble=preamble, model="command-r", temperature=0.3, force_single_step=True)
95+
96+
97+
print("Final answer:")
98+
print(response.text)

cookbook/langchain/trace_class_call_method.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212

1313
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
1414

15-
from langsmith.evaluation import LangChainStringEvaluator
16-
17-
qa_evaluator = [LangChainStringEvaluator("cot_qa")]
18-
1915

2016
class LangChainModule:
2117
handler = PareaAILangchainTracer()

parea/wrapper/cohere/helpers.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from typing import Any, Dict, List, Optional, Tuple, Union
2+
3+
import functools
4+
5+
import cohere
6+
from attrs import asdict, define
7+
from cohere import ApiMetaBilledUnits, NonStreamedChatResponse, RerankResponse
8+
9+
from parea.constants import COHERE_MODEL_INFO, COHERE_SEARCH_MODELS
10+
from parea.schemas import Message, Role
11+
from parea.utils.universal_encoder import json_dumps
12+
13+
DEFAULT_MODEL = "command-r-plus"
14+
DEFAULT_TEMPERATURE = 0.3
15+
DEFAULT_P = 0.75
16+
17+
18+
@define
class CohereOutput:
    """Normalized, trace-loggable view of a Cohere chat or rerank response.

    Each field is a JSON-encoded string (or None when absent from the
    response); the whole object is itself serialized with ``json_dumps``
    before being attached to a trace.
    """

    text: Optional[str] = None  # final answer text, or serialized tool calls when there is no text
    citations: Optional[str] = None  # JSON list of citation objects
    documents: Optional[str] = None  # JSON list of documents (rerank results for RerankResponse)
    search_queries: Optional[str] = None  # JSON list of generated search queries
    search_results: Optional[str] = None  # JSON list of connector search results
25+
26+
27+
def chat_history_to_messages(result: NonStreamedChatResponse, **kwargs) -> list[Message]:
    """Assemble the full conversation as parea Messages.

    Order: optional system preamble, any caller-supplied chat history, then
    the history carried on the response itself.
    """
    messages: list[Message] = []

    preamble = kwargs.get("preamble", "")
    if preamble:
        messages.append(Message(content=preamble, role=Role.system))

    prior_history = kwargs.get("chat_history", [])
    if prior_history:
        messages.extend(to_messages(prior_history))

    messages.extend(to_messages([m.dict() for m in result.chat_history]))
    return messages
36+
37+
38+
def to_messages(chat_history: List[Union[Dict, cohere.Message]]) -> List[Message]:
    """Convert Cohere chat-history entries (plain dicts or cohere.Message objects) to parea Messages."""
    # Map Cohere role names onto parea roles; unknown roles fall back to user.
    role_map = {"USER": Role.user, "CHATBOT": Role.assistant, "SYSTEM": Role.system, "TOOL": Role.tool}

    def process_message(message: Union[Dict, cohere.Message]) -> Message:
        # Entries may arrive as dicts (from kwargs) or SDK Message objects
        # (from result.chat_history); normalize both shapes here.
        if isinstance(message, dict):
            role = role_map.get(message["role"], Role.user)
            content = message.get("message", "")
            tool_calls = message.get("tool_calls") or message.get("tool_results")
        else:  # cohere.Message
            role = role_map.get(message.role, Role.user)
            # TOOL messages carry their payload in tool_results, not .message.
            content = "" if role == Role.tool else message.message
            tool_calls = getattr(message, "tool_calls", None) or getattr(message, "tool_results", None)

        if tool_calls:
            # Serialize tool calls/results; entries may be dicts or objects with .dict().
            tc = json_dumps([t.dict() if hasattr(t, "dict") else t for t in tool_calls])
            # NOTE(review): `tc` is already a JSON string, so wrapping it in another
            # json_dumps double-encodes it inside {"message", "tool_calls"} — confirm intended.
            content = tc if role == Role.tool or not content else json_dumps({"message": content, "tool_calls": tc})

        return Message(content=content, role=role)

    return list(map(process_message, chat_history))
58+
59+
60+
@functools.lru_cache(maxsize=128)
def compute_cost(prompt_tokens: int, completion_tokens: int, search_units: int, is_search_model: bool, model: str) -> float:
    """Return the USD cost of a Cohere call from its billed units.

    Token prices in COHERE_MODEL_INFO are per 1M tokens; search pricing is per
    1K search units and only applies to search models. Unknown models (or
    missing price entries) are treated as free rather than raising.
    """
    cost_per_token = COHERE_MODEL_INFO.get(model, {})
    # Use .get for every key: a model entry present in COHERE_MODEL_INFO but
    # missing "prompt"/"completion" previously raised KeyError.
    cost = (prompt_tokens * cost_per_token.get("prompt", 0) + completion_tokens * cost_per_token.get("completion", 0)) / 1_000_000
    if is_search_model:
        cost += search_units * cost_per_token.get("search", 0) / 1_000
    # Round to avoid float noise in logged costs.
    return round(cost, 10)
68+
69+
70+
def get_usage_stats(result: Optional[NonStreamedChatResponse | RerankResponse], model: str) -> Tuple[int, int, float]:
    """Extract (prompt_tokens, completion_tokens, cost) from a Cohere response.

    Returns (0, 0, 0.0) when the response is missing, has no meta, or has no
    billed units.
    """
    # Guard both result and result.meta: previously a response with meta=None
    # raised AttributeError.
    bu: Optional[ApiMetaBilledUnits] = result.meta.billed_units if result and result.meta else None
    if not bu:
        return 0, 0, 0.0
    prompt_tokens = bu.input_tokens or 0
    completion_tokens = bu.output_tokens or 0
    search_units = bu.search_units or 0
    is_search_model: bool = model in COHERE_SEARCH_MODELS
    cost = compute_cost(prompt_tokens, completion_tokens, search_units, is_search_model, model)
    return prompt_tokens, completion_tokens, cost
80+
81+
82+
def get_output(result: Optional[NonStreamedChatResponse | RerankResponse]) -> str:
    """Serialize a Cohere chat or rerank response into a JSON string for trace logs."""
    if not result:
        return ""

    # Rerank responses only carry ranked documents.
    if isinstance(result, RerankResponse):
        ranked = cohere_json_list(result.results) if result.results else None
        return json_dumps(asdict(CohereOutput(documents=ranked)))

    def encode_if_present(items) -> Optional[str]:
        return cohere_json_list(items) if items else None

    # Fall back to serialized tool calls when the response has no text.
    answer = result.text or cohere_json_list(result.tool_calls)
    payload = CohereOutput(
        text=answer,
        citations=encode_if_present(result.citations),
        documents=encode_if_present(result.documents),
        search_queries=encode_if_present(result.search_queries),
        search_results=encode_if_present(result.search_results),
    )
    return json_dumps(asdict(payload))
99+
100+
101+
def cohere_json_list(obj: Any) -> str:
    """JSON-encode an iterable of items, converting non-dict entries via their .dict() method.

    A None/empty input encodes as "[]".
    """
    normalized = [item if isinstance(item, dict) else item.dict() for item in (obj or [])]
    return json_dumps(normalized)

parea/wrapper/cohere/wrap_cohere.py

Lines changed: 9 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,17 @@
1111
import traceback
1212

1313
import cohere
14-
from attrs import asdict, define
15-
from cohere import ApiMetaBilledUnits, NonStreamedChatResponse, RerankResponse
14+
from cohere import NonStreamedChatResponse, RerankResponse
1615

17-
from parea.constants import COHERE_MODEL_INFO, COHERE_SEARCH_MODELS, PAREA_OS_ENV_EXPERIMENT_UUID
16+
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
1817
from parea.helpers import gen_trace_id, is_logging_disabled, timezone_aware_now
19-
from parea.schemas import LLMInputs, Message, ModelParams, Role, TraceLog, UpdateTraceScenario
18+
from parea.schemas import LLMInputs, ModelParams, TraceLog, UpdateTraceScenario
2019
from parea.utils.trace_utils import execution_order_counters, fill_trace_data, logger_record_log, trace_context, trace_data
2120
from parea.utils.universal_encoder import json_dumps
21+
from parea.wrapper.cohere.helpers import DEFAULT_MODEL, DEFAULT_P, DEFAULT_TEMPERATURE, chat_history_to_messages, get_output, get_usage_stats
2222

2323
logger = logging.getLogger()
2424

25-
DEFAULT_MODEL = "command-r-plus"
26-
DEFAULT_TEMPERATURE = 0.3
27-
DEFAULT_P = 0.75
28-
29-
30-
@define
31-
class CohereOutput:
32-
text: Optional[str] = None
33-
citations: Optional[str] = None
34-
documents: Optional[str] = None
35-
search_queries: Optional[str] = None
36-
search_results: Optional[str] = None
37-
3825

3926
class CohereClientWrapper:
4027
@staticmethod
@@ -201,7 +188,7 @@ def _fill_llm_config(trace_id: str, result: Optional[NonStreamedChatResponse | R
201188
"""
202189
try:
203190
model = kwargs.get("model", DEFAULT_MODEL)
204-
tools = kwargs.get("tools")
191+
tools = kwargs.get("tools", None)
205192
configuration = LLMInputs(
206193
model=model,
207194
provider="cohere",
@@ -213,13 +200,13 @@ def _fill_llm_config(trace_id: str, result: Optional[NonStreamedChatResponse | R
213200
max_length=kwargs.get("max_tokens"),
214201
response_format=kwargs.get("response_format"),
215202
),
216-
messages=CohereClientWrapper._chat_history_to_messages(result, **kwargs) if isinstance(result, NonStreamedChatResponse) else None,
217-
functions=json_dumps(tools) if tools else None,
203+
messages=chat_history_to_messages(result, **kwargs) if isinstance(result, NonStreamedChatResponse) else None,
204+
functions=tools,
218205
)
219-
prompt_tokens, completion_tokens, cost = CohereClientWrapper._get_usage_stats(result, model)
206+
prompt_tokens, completion_tokens, cost = get_usage_stats(result, model)
220207
data = {
221208
"configuration": configuration,
222-
"output": CohereClientWrapper._get_output(result),
209+
"output": get_output(result),
223210
"input_tokens": prompt_tokens,
224211
"output_tokens": completion_tokens,
225212
"total_tokens": prompt_tokens + completion_tokens,
@@ -230,77 +217,6 @@ def _fill_llm_config(trace_id: str, result: Optional[NonStreamedChatResponse | R
230217
logger.debug(f"Error occurred filling LLM config for trace {trace_id}, {e}", exc_info=True)
231218
fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
232219

233-
@staticmethod
234-
def _chat_history_to_messages(result: NonStreamedChatResponse, **kwargs) -> list[Message]:
235-
messages: list[Message] = []
236-
if sys_message := kwargs.get("preamble", ""):
237-
messages.append(Message(content=sys_message, role=Role.system))
238-
if history := kwargs.get("chat_history", []):
239-
messages.extend(CohereClientWrapper._to_messages(history))
240-
241-
messages.extend(CohereClientWrapper._to_messages([m.dict() for m in result.chat_history]))
242-
return messages
243-
244-
@staticmethod
245-
def _to_messages(chat_history: list[dict]) -> list[Message]:
246-
messages: list[Message] = []
247-
for message in chat_history:
248-
if message["role"] == "USER":
249-
messages.append(Message(content=message["message"], role=Role.user))
250-
elif message["role"] == "CHATBOT":
251-
messages.append(Message(content=message["message"], role=Role.assistant))
252-
elif message["role"] == "SYSTEM":
253-
messages.append(Message(content=message["message"], role=Role.system))
254-
elif message["role"] == "TOOL":
255-
messages.append(Message(content=json_dumps(message["tool_calls"]), role=Role.tool))
256-
257-
return messages
258-
259-
@staticmethod
260-
@functools.lru_cache(maxsize=128)
261-
def _compute_cost(prompt_tokens: int, completion_tokens: int, search_units: int, is_search_model: bool, model: str) -> float:
262-
cost_per_token = COHERE_MODEL_INFO.get(model, {"prompt": 0, "completion": 0})
263-
cost = ((prompt_tokens * cost_per_token["prompt"]) + (completion_tokens * cost_per_token["completion"])) / 1_000_000
264-
if is_search_model:
265-
cost += search_units * cost_per_token.get("search", 0) / 1_000
266-
cost = round(cost, 10)
267-
return cost
268-
269-
@staticmethod
270-
def _get_usage_stats(result: Optional[NonStreamedChatResponse | RerankResponse], model: str) -> Tuple[int, int, float]:
271-
bu: Optional[ApiMetaBilledUnits] = result.meta.billed_units if result else None
272-
if not bu:
273-
return 0, 0, 0.0
274-
prompt_tokens = bu.input_tokens or 0
275-
completion_tokens = bu.output_tokens or 0
276-
search_units = bu.search_units or 0
277-
is_search_model: bool = model in COHERE_SEARCH_MODELS
278-
cost = CohereClientWrapper._compute_cost(prompt_tokens, completion_tokens, search_units, is_search_model, model)
279-
return prompt_tokens, completion_tokens, cost
280-
281-
@staticmethod
282-
def _get_output(result: Optional[NonStreamedChatResponse | RerankResponse]) -> str:
283-
if not result:
284-
return ""
285-
286-
if isinstance(result, RerankResponse):
287-
output = CohereOutput(documents=CohereClientWrapper._cohere_json_list(result.results) if result.results else None)
288-
return json_dumps(asdict(output))
289-
290-
text = result.text or CohereClientWrapper._cohere_json_list(result.tool_calls)
291-
output = CohereOutput(
292-
text=text,
293-
citations=CohereClientWrapper._cohere_json_list(result.citations) if result.citations else None,
294-
documents=CohereClientWrapper._cohere_json_list(result.documents) if result.documents else None,
295-
search_queries=CohereClientWrapper._cohere_json_list(result.search_queries) if result.search_queries else None,
296-
search_results=CohereClientWrapper._cohere_json_list(result.search_results) if result.search_results else None,
297-
)
298-
return json_dumps(asdict(output))
299-
300-
@staticmethod
301-
def _cohere_json_list(obj: Any) -> str:
302-
return json_dumps([o.dict() for o in obj])
303-
304220
@staticmethod
305221
def init(client: Union[cohere.Client, cohere.AsyncClient]) -> None:
306222
"""

0 commit comments

Comments
 (0)