
Commit 5ed5c09

feat(cohere): auto trace cohere
1 parent 75b69c3 commit 5ed5c09

File tree

12 files changed: +2530 -1512 lines changed

cookbook/ab_testing.py

Lines changed: 24 additions & 19 deletions
```diff
@@ -1,15 +1,16 @@
+from typing import Tuple
+
 import os
 import random
-from typing import Tuple
 
 from openai import OpenAI
 
-from parea import trace, trace_insert, Parea, get_current_trace_id
+from parea import Parea, get_current_trace_id, trace, trace_insert
 from parea.schemas import FeedbackRequest
 
 client = OpenAI()
 # instantiate Parea client
-p = Parea(api_key=os.getenv('PAREA_API_KEY'))
+p = Parea(api_key=os.getenv("PAREA_API_KEY"))
 # wrap OpenAI client to trace calls
 p.wrap_openai_client(client)
 
@@ -19,28 +20,32 @@ def generate_email(user: str) -> Tuple[str, str]:
     """Randomly chooses a prompt to perform an A/B test for generating email. Returns the email and the trace ID.
     The latter is used to tie back the collected feedback from the user."""
     if random.random() < 0.5:
-        trace_insert({'metadata': {'ab_test_0': 'variant_0'}})
-        prompt = f'Generate a long email for {user}'
+        trace_insert({"metadata": {"ab_test_0": "variant_0"}})
+        prompt = f"Generate a long email for {user}"
     else:
-        trace_insert({'metadata': {'ab_test_0': 'variant_1'}})
-        prompt = f'Generate a short email for {user}'
-
-    email = client.chat.completions.create(
-        model="gpt-4o",
-        messages=[
-            {
-                "role": "user",
-                "content": prompt,
-            }
-        ],
-    ).choices[0].message.content
+        trace_insert({"metadata": {"ab_test_0": "variant_1"}})
+        prompt = f"Generate a short email for {user}"
+
+    email = (
+        client.chat.completions.create(
+            model="gpt-4o",
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+        )
+        .choices[0]
+        .message.content
+    )
 
     return email, get_current_trace_id()
 
 
 def main():
     # generate email and get trace ID
-    email, trace_id = generate_email('Max Mustermann')
+    email, trace_id = generate_email("Max Mustermann")
 
     # log user feedback on email using trace ID
     p.record_feedback(
@@ -51,5 +56,5 @@ def main():
     )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
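
The trace ID returned by `generate_email` is what ties later user feedback back to the A/B variant recorded in the trace metadata. The arguments to `record_feedback` are elided by the diff context above; as a purely hypothetical sketch of how a caller might close the loop (the field names are assumed from the `FeedbackRequest` import, not taken from this commit):

```python
# Hypothetical sketch, not the elided code above: map a user reaction to a
# 0-1 score and attach it to the traced A/B variant via its trace ID.
# Assumes FeedbackRequest accepts trace_id and score fields.
def on_user_reaction(trace_id: str, liked: bool) -> None:
    p.record_feedback(FeedbackRequest(trace_id=trace_id, score=1.0 if liked else 0.0))
```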

cookbook/cohere/trace_cohere.py

Lines changed: 44 additions & 0 deletions
New file:

```python
import os

import cohere
from dotenv import load_dotenv

from parea import Parea

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)

response = co.chat(
    model="command-r-plus",
    preamble="You are a helpful assistant talking in JSON.",
    message="Generate a JSON describing a person, with the fields 'name' and 'age'",
    response_format={"type": "json_object"},
)
print(response)
print("\n\n")

response = co.chat(message="Who discovered gravity?")
print(response)
print("\n\n")

docs = [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
    "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
]
response = co.rerank(
    model="rerank-english-v3.0",
    query="What is the capital of the United States?",
    documents=docs,
    top_n=3,
)
print(response)
print("\n\n")

for event in co.chat_stream(message="Who discovered gravity?"):
    print(event)
```
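
The wrapper's signature in `parea/client.py` below also accepts `cohere.AsyncClient`, so the same one-line wrapping should cover async usage. A minimal sketch under that assumption (not part of this commit):

```python
# Sketch only: assumes wrap_cohere_client traces cohere.AsyncClient the same
# way it traces the sync client, as its type hint below suggests.
import asyncio


async def traced_async_chat() -> None:
    aco = cohere.AsyncClient(api_key=os.getenv("COHERE_API_KEY"))
    p.wrap_cohere_client(aco)
    response = await aco.chat(model="command-r-plus", message="Who discovered gravity?")
    print(response.text)


asyncio.run(traced_async_chat())
```
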
Lines changed: 82 additions & 0 deletions
New file:

```python
from typing import List, Optional

import os
from datetime import datetime

import cohere
from dotenv import load_dotenv

from parea import Parea, trace, trace_insert

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co)


def call_llm(message: str, chat_history: Optional[List[dict]] = None, system_message: str = "", model: str = "command-r-plus") -> str:
    return co.chat(
        model=model,
        preamble=system_message,
        chat_history=chat_history or [],
        message=message,
    ).text


@trace
def argumentor(query: str, additional_description: str = "") -> str:
    return call_llm(
        system_message=f"""You are a debater making an argument on a topic. {additional_description}.
        The current time is {datetime.now().strftime("%Y-%m-%d")}""",
        message=f"The discussion topic is {query}",
    )


@trace
def critic(argument: str) -> str:
    return call_llm(
        system_message="""You are a critic.
        What unresolved questions or criticism do you have after reading the following argument?
        Provide a concise summary of your feedback.""",
        message=argument,
    )


@trace
def refiner(query: str, additional_description: str, argument: str, criticism: str) -> str:
    return call_llm(
        system_message=f"""You are a debater making an argument on a topic. {additional_description}.
        The current time is {datetime.now().strftime("%Y-%m-%d")}""",
        chat_history=[
            {"role": "USER", "message": f"The discussion topic is {query}"},
            {"role": "CHATBOT", "message": argument},
            {"role": "USER", "message": criticism},
        ],
        message="Please generate a new argument that incorporates the feedback from the user.",
    )


@trace
def argument_chain(query: str, additional_description: str = "") -> str:
    trace_insert({"session_id": "cus_1234", "end_user_identifier": "user_1234"})
    argument = argumentor(query, additional_description)
    criticism = critic(argument)
    refined_argument = refiner(query, additional_description, argument, criticism)
    return refined_argument


@trace(session_id="cus_1234", end_user_identifier="user_1234")
def json_call() -> str:
    completion = co.chat(
        model="command-r-plus",
        preamble="You are a helpful assistant talking in JSON.",
        message="What are you?",
        response_format={"type": "json_object"},
    )
    return completion.text


if __name__ == "__main__":
    result = argument_chain(
        "Whether sparkling wine is good for you.",
        additional_description="Provide a concise, few sentence argument on why sparkling wine is good for you.",
    )
    # print(result)
    # print(json_call())
```
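
This file demonstrates both ways the SDK attaches session metadata: `trace_insert` mutates the currently open trace at runtime (as in `argument_chain`), while decorator keywords like `@trace(session_id=...)` fix it at definition time (as in `json_call`). A small sketch of the runtime route, reusing `call_llm` from above (the function itself is illustrative, not part of the commit):

```python
@trace
def followup(question: str, user_id: str) -> str:
    # Tag the currently open trace at runtime instead of
    # hard-coding identifiers in the decorator.
    trace_insert({"end_user_identifier": user_id})
    return call_llm(message=question)
```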

cookbook/langchain/trace_class_call_method.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -12,6 +12,10 @@
 
 p = Parea(api_key=os.getenv("PAREA_API_KEY"))
 
+from langsmith.evaluation import LangChainStringEvaluator
+
+qa_evaluator = [LangChainStringEvaluator("cot_qa")]
+
 
 class LangChainModule:
     handler = PareaAILangchainTracer()
```

parea/client.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -118,6 +118,13 @@ def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str]
         if integration:
             self._add_integration(integration)
 
+    def wrap_cohere_client(self, client: Union["cohere.Client", "cohere.AsyncClient"], integration: Optional[str] = None) -> None:
+        from parea.wrapper.cohere.wrap_cohere import CohereClientWrapper
+
+        CohereClientWrapper().init(client=client)
+        if integration:
+            self._add_integration(integration)
+
     def _add_integration(self, integration: str) -> None:
         self._client.add_integration(integration)
 
```
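
The new method mirrors `wrap_anthropic_client`: delegate instrumentation to `CohereClientWrapper`, then optionally register an integration tag. Caller-side it looks roughly like this (a sketch; the integration label is made up for illustration):

```python
import os

import cohere

from parea import Parea

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
co = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
p.wrap_cohere_client(co, integration="my-integration")  # hypothetical label
```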

parea/constants.py

Lines changed: 72 additions & 1 deletion
```diff
@@ -371,7 +371,78 @@ def str2bool(v):
         "completion": 15.00,
     },
 }
-ALL_NON_AZURE_MODELS_INFO = {**OPENAI_MODEL_INFO, **ANTHROPIC_MODEL_INFO}
+COHERE_MODEL_INFO: Dict[str, Dict[str, Union[float, int, Dict[str, int]]]] = {
+    "command-r-plus": {
+        "prompt": 3.0,
+        "completion": 15.0,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
+    },
+    "command-r": {
+        "prompt": 0.5,
+        "completion": 1.5,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
+    },
+    "command": {
+        "prompt": 1.0,
+        "completion": 2.0,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 4096},
+    },
+    "command-nightly": {
+        "prompt": 1.0,
+        "completion": 2.0,
+        "token_limit": {"max_completion_tokens": 128000, "max_prompt_tokens": 128000},
+    },
+    "command-light": {
+        "prompt": 0.3,
+        "completion": 0.6,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 4096},
+    },
+    "command-light-nightly": {
+        "prompt": 0.3,
+        "completion": 0.6,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 4096},
+    },
+    "c4ai-aya-23": {
+        "prompt": 0.0,
+        "completion": 0.0,
+        "token_limit": {"max_completion_tokens": 8192, "max_prompt_tokens": 8192},
+    },
+    "rerank-english-v3.0": {
+        "prompt": 0,
+        "completion": 0,
+        # $ per 1K
+        "search": 2.0,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 4096},
+    },
+    "rerank-multilingual-v3.0": {
+        "prompt": 0,
+        "completion": 0,
+        # $ per 1K
+        "search": 2.0,
+        "token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 4096},
+    },
+    "rerank-english-v2.0": {
+        "prompt": 0,
+        "completion": 0,
+        # $ per 1K
+        "search": 1.0,
+        "token_limit": {"max_completion_tokens": 512, "max_prompt_tokens": 512},
+    },
+    "rerank-multilingual-v2.0": {
+        "prompt": 0,
+        "completion": 0,
+        # $ per 1K
+        "search": 1.0,
+        "token_limit": {"max_completion_tokens": 512, "max_prompt_tokens": 512},
+    },
+}
+COHERE_SEARCH_MODELS: set[str] = {
+    "rerank-english-v3.0",
+    "rerank-multilingual-v3.0",
+    "rerank-english-v2.0",
+    "rerank-multilingual-v2.0",
+}
+ALL_NON_AZURE_MODELS_INFO = {**OPENAI_MODEL_INFO, **ANTHROPIC_MODEL_INFO, **COHERE_MODEL_INFO}
 
 NOUNS = (
     "abac",
```

parea/experiment/experiment.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -12,7 +12,6 @@
 
 from attrs import define, field
 from tqdm import tqdm
-from tqdm.asyncio import tqdm_asyncio
 
 from parea import Parea
 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
@@ -68,7 +67,7 @@ def apply_dataset_eval(dataset_level_evals: List[Callable]) -> List[EvaluationRe
         try:
             result = dataset_level_eval(root_traces)
         except Exception as e:
-            logger.exception(f"Error occurred calling dataset level eval function '{dataset_level_eval.__name__}': {e}", exc_info=e)
+            logger.error(f"Error occurred calling dataset level eval function '{dataset_level_eval.__name__}': {e}", exc_info=e)
             continue
         if result is None:
             continue
```

parea/utils/trace_utils.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -245,7 +245,7 @@ def cleanup_trace(trace_id: str, start_time: datetime, context_token: contextvar
             output = access_output_of_func(output)
             output_for_eval_metrics = json_dumps(output)
         except Exception as e:
-            logger.exception(f"Error accessing output of func with output: {output}. Error: {e}", exc_info=e)
+            logger.error(f"Error accessing output of func with output: {output}. Error: {e}", exc_info=e)
     trace_data.get()[trace_id].output_for_eval_metrics = output_for_eval_metrics
 
     thread_eval_funcs_then_log(trace_id, eval_funcs)
@@ -263,7 +263,7 @@ async def async_wrapper(*args, **kwargs):
             fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
             return result
         except Exception as e:
-            logger.exception(f"Error occurred in function {func.__name__}, {e}")
+            logger.error(f"Error occurred in function {func.__name__}, {e}")
             fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
             raise e
         finally:
@@ -283,7 +283,7 @@ def wrapper(*args, **kwargs):
             fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
             return result
         except Exception as e:
-            logger.exception(f"Error occurred in function {func.__name__}, {e}")
+            logger.error(f"Error occurred in function {func.__name__}, {e}")
             fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
             raise e
         finally:
@@ -326,7 +326,7 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: List[Callable] = None):
             elif score is not None:
                 scores.append(EvaluationResult(name=func.__name__, score=score))
         except Exception as e:
-            logger.exception(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)
+            logger.error(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)
     trace_data.get()[trace_id].scores = scores
     thread_ids_running_evals.get().remove(trace_id)
```
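
These hunks, like the one in parea/experiment/experiment.py above, swap `logger.exception` for `logger.error`. Both log at ERROR level; `exception()` merely implies `exc_info=True`, so where `exc_info=e` is still passed explicitly the traceback is preserved, and `error()` is also safe to call outside an `except` block. For illustration:

```python
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

try:
    1 / 0
except ZeroDivisionError as e:
    # Both emit an ERROR record carrying the same traceback.
    logger.exception("boom: %s", e)
    logger.error("boom: %s", e, exc_info=e)
```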
