diff --git a/agentops/partners/langchain_callback_handler.py b/agentops/partners/langchain_callback_handler.py
index 3803b4888..768097dcf 100644
--- a/agentops/partners/langchain_callback_handler.py
+++ b/agentops/partners/langchain_callback_handler.py
@@ -9,14 +9,12 @@
 from langchain_core.agents import AgentFinish, AgentAction
 from langchain_core.documents import Document
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
-from langchain.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler
+from langchain_core.callbacks.base import BaseCallbackHandler, AsyncCallbackHandler
 from langchain_core.messages import BaseMessage
-
 from agentops import Client as AOClient
 from agentops import ActionEvent, LLMEvent, ToolEvent, ErrorEvent
 from agentops.helpers import get_ISO_time, debug_print_function_params
-
 from ..log_config import logger
@@ -46,7 +44,7 @@ def __init__(
         endpoint: Optional[str] = None,
         max_wait_time: Optional[int] = None,
         max_queue_size: Optional[int] = None,
-        default_tags: Optional[List[str]] = None,
+        default_tags: List[str] = ["langchain", "sync"],
     ):
         logging_level = os.getenv("AGENTOPS_LOGGING_LEVEL")
         log_levels = {
             "CRITICAL": logging.CRITICAL,
             "ERROR": logging.ERROR,
             "INFO": logging.INFO,
             "WARNING": logging.WARNING,
             "DEBUG": logging.DEBUG,
         }
@@ -93,14 +91,67 @@ def on_llm_start(
     ) -> Any:
         self.events.llm[str(run_id)] = LLMEvent(
             params={
-                **serialized,
-                **({} if metadata is None else metadata),
-                **kwargs,
-            },  # TODO: params is inconsistent, in ToolEvent we put it in logs
+                "serialized": serialized,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+            },
             model=get_model_from_kwargs(kwargs),
             prompt=prompts[0],
         )
 
+    @debug_print_function_params
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        parsed_messages = [
+            {"role": message.type, "content": message.content}
+            for message in messages[0]
+            if message.type in ["system", "human"]
+        ]
+
+        action_event = ActionEvent(
+            params={
+                "serialized": serialized,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+                "messages": parsed_messages,
+            },
+            action_type="on_chat_model_start",
+        )
+        self.ao_client.record(action_event)
+
+        # Initialize LLMEvent here since on_llm_start isn't called for chat models
+        self.events.llm[str(run_id)] = LLMEvent(
+            params={
+                "serialized": serialized,
+                "messages": messages,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+            },
+            model=get_model_from_kwargs(kwargs),
+            prompt=parsed_messages,
+            completion="",
+            returns={},
+        )
+
     @debug_print_function_params
     def on_llm_error(
         self,
@@ -111,7 +162,11 @@ def on_llm_error(
         **kwargs: Any,
     ) -> Any:
         llm_event: LLMEvent = self.events.llm[str(run_id)]
-        error_event = ErrorEvent(trigger_event=llm_event, exception=error)
+        error_event = ErrorEvent(
+            trigger_event=llm_event,
+            exception=error,
+            details={"run_id": run_id, "parent_run_id": parent_run_id, "kwargs": kwargs},
+        )
         self.ao_client.record(error_event)
 
     @debug_print_function_params
@@ -124,25 +179,34 @@ def on_llm_end(
         **kwargs: Any,
     ) -> Any:
         llm_event: LLMEvent = self.events.llm[str(run_id)]
-        llm_event.returns = {
-            "content": response.generations[0][0].text,
-            "generations": response.generations,
-        }
+        llm_event.returns = response
         llm_event.end_timestamp = get_ISO_time()
-        llm_event.completion = response.generations[0][0].text
-        if response.llm_output is not None:
-            llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
-            llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
 
         if len(response.generations) == 0:
-            # TODO: more descriptive error
             error_event = ErrorEvent(
                 trigger_event=self.events.llm[str(run_id)],
                 error_type="NoGenerations",
-                details="on_llm_end: No generations",
+                details={"run_id": run_id, "parent_run_id": parent_run_id, "kwargs": kwargs},
             )
             self.ao_client.record(error_event)
         else:
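+            # ChatGeneration messages carry type "ai" for complete responses and
+            # "AIMessageChunk" when the response was aggregated from streamed chunks.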
+            for generation in response.generations[0]:
+                # Plain completion LLMs emit Generation objects without a message attribute
+                message = getattr(generation, "message", None)
+                if message is None:
+                    if generation.text:
+                        llm_event.completion = generation.text
+                elif (
+                    message.type == "ai"
+                    and generation.text
+                    and llm_event.completion != generation.text
+                ):
+                    llm_event.completion = generation.text
+                elif (
+                    message.type == "AIMessageChunk"
+                    and message.content
+                    and llm_event.completion != message.content
+                ):
+                    llm_event.completion += message.content
+
+            if response.llm_output is not None:
+                llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
+                llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
 
         self.ao_client.record(llm_event)
 
     @debug_print_function_params
@@ -157,18 +221,24 @@ def on_chain_start(
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
-        try:
-            self.events.chain[str(run_id)] = ActionEvent(
-                params={
-                    **serialized,
-                    **inputs,
-                    **({} if metadata is None else metadata),
-                    **kwargs,
-                },
-                action_type="chain",
-            )
-        except Exception as e:
-            logger.warning(e)
+        # Initialize with empty dicts if None
+        serialized = serialized or {}
+        inputs = inputs or {}
+        metadata = metadata or {}
+
+        self.events.chain[str(run_id)] = ActionEvent(
+            params={
+                "serialized": serialized,
+                "inputs": inputs,
+                "metadata": metadata,
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+            },
+            action_type="on_chain_start",
+        )
 
     @debug_print_function_params
     def on_chain_end(
@@ -193,7 +263,11 @@ def on_chain_error(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        action_event: ActionEvent = self.events.chain[str(run_id)]
+        # Create a new ActionEvent if one doesn't exist for this run_id
+        if str(run_id) not in self.events.chain:
+            self.events.chain[str(run_id)] = ActionEvent(params=kwargs, action_type="on_chain_error")
+
+        action_event = self.events.chain[str(run_id)]
         error_event = ErrorEvent(trigger_event=action_event, exception=error)
         self.ao_client.record(error_event)
 
@@ -211,14 +285,16 @@ def on_tool_start(
         **kwargs: Any,
     ) -> Any:
         self.events.tool[str(run_id)] = ToolEvent(
-            params=input_str if inputs is None else inputs,
-            name=serialized["name"],
+            params=inputs,
+            name=serialized.get("name"),
             logs={
-                **serialized,
+                "serialized": serialized,
+                "input_str": input_str,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
                 "tags": tags,
-                **({} if metadata is None else metadata),
-                **({} if inputs is None else inputs),
-                **kwargs,
             },
         )
 
@@ -235,8 +311,6 @@ def on_tool_end(
         tool_event: ToolEvent = self.events.tool[str(run_id)]
         tool_event.end_timestamp = get_ISO_time()
         tool_event.returns = output
 
-        # Tools are capable of failing `on_tool_end` quietly.
-        # This is a workaround to make sure we can log it as an error.
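+        # LangChain can report a failed tool run through on_tool_end with the tool
+        # name "_Exception"; record those runs as errors rather than successes.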
         if kwargs.get("name") == "_Exception":
             error_event = ErrorEvent(
                 trigger_event=tool_event,
                 error_type="LangchainToolException",
                 details=output,
             )
             self.ao_client.record(error_event)
@@ -271,15 +345,18 @@ def on_retriever_start(
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Any:
         self.events.retriever[str(run_id)] = ActionEvent(
             params={
-                **serialized,
+                "serialized": serialized,
                 "query": query,
-                **({} if metadata is None else metadata),
-                **kwargs,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
             },
-            action_type="retriever",
+            action_type="on_retriever_start",
         )
 
     @debug_print_function_params
@@ -291,9 +368,9 @@ def on_retriever_end(
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Any:
         action_event: ActionEvent = self.events.retriever[str(run_id)]
-        action_event.logs = documents  # TODO: Adding this. Might want to add elsewhere e.g. params
+        action_event.returns = documents
         action_event.end_timestamp = get_ISO_time()
         self.ao_client.record(action_event)
 
@@ -306,7 +383,7 @@ def on_retriever_error(
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Any:
         action_event: ActionEvent = self.events.retriever[str(run_id)]
         error_event = ErrorEvent(trigger_event=action_event, exception=error)
         self.ao_client.record(error_event)
 
@@ -320,7 +397,9 @@ def on_agent_action(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent"))
+        self.agent_actions[run_id].append(
+            ActionEvent(params={"action": action, **kwargs}, action_type="on_agent_action")
+        )
 
     @debug_print_function_params
     def on_agent_finish(
@@ -331,15 +410,10 @@ def on_agent_finish(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        # Need to attach finish to some on_agent_action so just choosing the last one
         self.agent_actions[run_id][-1].returns = finish.to_json()
-
         for agentAction in self.agent_actions[run_id]:
             self.ao_client.record(agentAction)
 
-        # TODO: Create a way for the end user to set this based on their conditions
-        # self.ao_client.end_session("Success") #TODO: calling end_session here causes "No current session"
-
     @debug_print_function_params
     def on_retry(
         self,
@@ -350,16 +424,35 @@ def on_retry(
         **kwargs: Any,
     ) -> Any:
         action_event = ActionEvent(
-            params={**kwargs},
-            returns=str(retry_state),
-            action_type="retry",
-            # result="Indeterminate" # TODO: currently have no way of recording Indeterminate
+            params={
+                "retry_state": retry_state,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "kwargs": kwargs,
+            },
+            action_type="on_retry",
         )
         self.ao_client.record(action_event)
 
-    @property
-    def session_id(self):
-        raise DeprecationWarning("session_id is deprecated in favor of current_session_ids")
+    @debug_print_function_params
+    def on_llm_new_token(
+        self,
+        token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on new LLM token. Only available when streaming is enabled."""
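+        # Fall back to a fresh LLMEvent if this run_id was never registered by
+        # on_llm_start or on_chat_model_start.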
+        if str(run_id) not in self.events.llm:
+            self.events.llm[str(run_id)] = LLMEvent(params=kwargs)
+            self.events.llm[str(run_id)].completion = ""
+
+        llm_event = self.events.llm[str(run_id)]
+        # Always append the new token to the existing completion
+        llm_event.completion += token
 
     @property
     def current_session_ids(self):
@@ -375,20 +468,39 @@ def __init__(
         endpoint: Optional[str] = None,
         max_wait_time: Optional[int] = None,
         max_queue_size: Optional[int] = None,
-        tags: Optional[List[str]] = None,
+        default_tags: List[str] = ["langchain", "async"],
     ):
+        logging_level = os.getenv("AGENTOPS_LOGGING_LEVEL")
+        log_levels = {
+            "CRITICAL": logging.CRITICAL,
+            "ERROR": logging.ERROR,
+            "INFO": logging.INFO,
+            "WARNING": logging.WARNING,
+            "DEBUG": logging.DEBUG,
+        }
+        logger.setLevel(log_levels.get(logging_level or "INFO", "INFO"))
+
         client_params: Dict[str, Any] = {
             "api_key": api_key,
             "endpoint": endpoint,
             "max_wait_time": max_wait_time,
             "max_queue_size": max_queue_size,
-            "tags": tags,
+            "default_tags": default_tags,
         }
 
-        self.ao_client = AOClient(**{k: v for k, v in client_params.items() if v is not None}, override=False)
+        self.ao_client = AOClient()
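+        # Configure the shared client only when no session is running yet, so an
+        # already-active AgentOps session keeps its existing settings.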
+        if self.ao_client.session_count == 0:
+            self.ao_client.configure(
+                **{k: v for k, v in client_params.items() if v is not None},
+                instrument_llm_calls=False,
+            )
+
+        if not self.ao_client.is_initialized:
+            self.ao_client.initialize()
 
-        self.events = Events()
         self.agent_actions: Dict[UUID, List[ActionEvent]] = defaultdict(list)
+        self.events = Events()
 
     @debug_print_function_params
     async def on_llm_start(
@@ -401,14 +513,17 @@ async def on_llm_start(
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         self.events.llm[str(run_id)] = LLMEvent(
             params={
-                **serialized,
-                **({} if metadata is None else metadata),
-                **kwargs,
-            },  # TODO: params is inconsistent, in ToolEvent we put it in logs
-            model=kwargs["invocation_params"]["model"],
+                "serialized": serialized,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+            },
+            model=get_model_from_kwargs(kwargs),
             prompt=prompts[0],
         )
 
     @debug_print_function_params
     async def on_chat_model_start(
@@ -423,8 +538,44 @@ async def on_chat_model_start(
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> Any:
-        pass
+    ) -> None:
+        """Run when a chat model starts running."""
+        parsed_messages = [
+            {"role": message.type, "content": message.content}
+            for message in messages[0]
+            if message.type in ["system", "human"]
+        ]
+
+        action_event = ActionEvent(
+            params={
+                "serialized": serialized,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+                "messages": parsed_messages,
+            },
+            action_type="on_chat_model_start",
+        )
+        self.ao_client.record(action_event)
+
+        # Initialize LLMEvent here since on_llm_start isn't called for chat models
+        self.events.llm[str(run_id)] = LLMEvent(
+            params={
+                "serialized": serialized,
+                "messages": messages,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+            },
+            model=get_model_from_kwargs(kwargs),
+            prompt=parsed_messages,
+            completion="",
+            returns={},
+        )
 
     @debug_print_function_params
     async def on_llm_new_token(
@@ -437,7 +588,14 @@ async def on_llm_new_token(
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        pass
+        """Run on new LLM token. Only available when streaming is enabled."""
+        if str(run_id) not in self.events.llm:
+            self.events.llm[str(run_id)] = LLMEvent(params=kwargs)
+            self.events.llm[str(run_id)].completion = ""
+
+        llm_event = self.events.llm[str(run_id)]
+        # Always append the new token to the existing completion
+        llm_event.completion += token
 
     @debug_print_function_params
     async def on_llm_error(
@@ -447,9 +605,13 @@ async def on_llm_error(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         llm_event: LLMEvent = self.events.llm[str(run_id)]
-        error_event = ErrorEvent(trigger_event=llm_event, exception=error)
+        error_event = ErrorEvent(
+            trigger_event=llm_event,
+            exception=error,
+            details={"run_id": run_id, "parent_run_id": parent_run_id, "kwargs": kwargs},
+        )
         self.ao_client.record(error_event)
 
     @debug_print_function_params
     async def on_llm_end(
@@ -460,27 +622,36 @@ async def on_llm_end(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         llm_event: LLMEvent = self.events.llm[str(run_id)]
-        llm_event.returns = {
-            "content": response.generations[0][0].text,
-            "generations": response.generations,
-        }
+        llm_event.returns = response
         llm_event.end_timestamp = get_ISO_time()
-        llm_event.completion = response.generations[0][0].text
-        if response.llm_output is not None:
-            llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
-            llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
 
         if len(response.generations) == 0:
-            # TODO: more descriptive error
             error_event = ErrorEvent(
-                trigger_event=llm_event,
+                trigger_event=self.events.llm[str(run_id)],
                 error_type="NoGenerations",
-                details="on_llm_end: No generations",
+                details={"run_id": run_id, "parent_run_id": parent_run_id, "kwargs": kwargs},
            )
             self.ao_client.record(error_event)
         else:
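+            # ChatGeneration messages carry type "ai" for complete responses and
+            # "AIMessageChunk" when the response was aggregated from streamed chunks.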
+            for generation in response.generations[0]:
+                # Plain completion LLMs emit Generation objects without a message attribute
+                message = getattr(generation, "message", None)
+                if message is None:
+                    if generation.text:
+                        llm_event.completion = generation.text
+                elif (
+                    message.type == "ai"
+                    and generation.text
+                    and llm_event.completion != generation.text
+                ):
+                    llm_event.completion = generation.text
+                elif (
+                    message.type == "AIMessageChunk"
+                    and message.content
+                    and llm_event.completion != message.content
+                ):
+                    llm_event.completion += message.content
+
+            if response.llm_output is not None:
+                llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
+                llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
 
         self.ao_client.record(llm_event)
 
     @debug_print_function_params
     async def on_chain_start(
@@ -494,15 +665,23 @@ async def on_chain_start(
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
+        # Initialize with empty dicts if None
+        serialized = serialized or {}
+        inputs = inputs or {}
+        metadata = metadata or {}
+
         self.events.chain[str(run_id)] = ActionEvent(
             params={
-                **serialized,
-                **inputs,
-                **({} if metadata is None else metadata),
-                **kwargs,
+                "serialized": serialized,
+                "inputs": inputs,
+                "metadata": metadata,
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
             },
-            action_type="chain",
+            action_type="on_chain_start",
         )
 
     @debug_print_function_params
     async def on_chain_end(
@@ -513,7 +692,7 @@ async def on_chain_end(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         action_event: ActionEvent = self.events.chain[str(run_id)]
         action_event.returns = outputs
         action_event.end_timestamp = get_ISO_time()
@@ -527,8 +706,12 @@ async def on_chain_error(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
-        action_event: ActionEvent = self.events.chain[str(run_id)]
+    ) -> None:
+        # Create a new ActionEvent if one doesn't exist for this run_id
+        if str(run_id) not in self.events.chain:
+            self.events.chain[str(run_id)] = ActionEvent(params=kwargs, action_type="on_chain_error")
+
+        action_event = self.events.chain[str(run_id)]
         error_event = ErrorEvent(trigger_event=action_event, exception=error)
         self.ao_client.record(error_event)
 
@@ -544,16 +727,18 @@ async def on_tool_start(
         metadata: Optional[Dict[str, Any]] = None,
         inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         self.events.tool[str(run_id)] = ToolEvent(
-            params=input_str if inputs is None else inputs,
-            name=serialized["name"],
+            params=inputs,
+            name=serialized.get("name"),
             logs={
-                **serialized,
+                "serialized": serialized,
+                "input_str": input_str,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
                 "tags": tags,
-                **({} if metadata is None else metadata),
-                **({} if inputs is None else inputs),
-                **kwargs,
             },
         )
 
@@ -565,13 +750,11 @@ async def on_tool_end(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         tool_event: ToolEvent = self.events.tool[str(run_id)]
         tool_event.end_timestamp = get_ISO_time()
         tool_event.returns = output
 
-        # Tools are capable of failing `on_tool_end` quietly.
-        # This is a workaround to make sure we can log it as an error.
         if kwargs.get("name") == "_Exception":
             error_event = ErrorEvent(
                 trigger_event=tool_event,
@@ -590,7 +773,7 @@ async def on_tool_error(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         tool_event: ToolEvent = self.events.tool[str(run_id)]
         error_event = ErrorEvent(trigger_event=tool_event, exception=error)
         self.ao_client.record(error_event)
 
@@ -609,12 +792,15 @@ async def on_retriever_start(
     ) -> None:
         self.events.retriever[str(run_id)] = ActionEvent(
             params={
-                **serialized,
+                "serialized": serialized,
                 "query": query,
-                **({} if metadata is None else metadata),
-                **kwargs,
+                "metadata": ({} if metadata is None else metadata),
+                "kwargs": kwargs,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "tags": tags,
             },
-            action_type="retriever",
+            action_type="on_retriever_start",
         )
 
@@ -628,7 +814,7 @@ async def on_retriever_end(
         **kwargs: Any,
     ) -> None:
         action_event: ActionEvent = self.events.retriever[str(run_id)]
-        action_event.logs = documents  # TODO: Adding this. Might want to add elsewhere e.g. params
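+        # Retrieved documents are the call's output, so record them on returns
+        # instead of logs.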
+        action_event.returns = documents
         action_event.end_timestamp = get_ISO_time()
         self.ao_client.record(action_event)
 
@@ -654,8 +840,10 @@ async def on_agent_action(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
-        self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent"))
+    ) -> None:
+        self.agent_actions[run_id].append(
+            ActionEvent(params={"action": action, **kwargs}, action_type="on_agent_action")
+        )
 
     @debug_print_function_params
     async def on_agent_finish(
@@ -665,28 +853,11 @@ async def on_agent_finish(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
-        # Need to attach finish to some on_agent_action so just choosing the last one
+    ) -> None:
         self.agent_actions[run_id][-1].returns = finish.to_json()
-
         for agentAction in self.agent_actions[run_id]:
             self.ao_client.record(agentAction)
 
-        # TODO: Create a way for the end user to set this based on their conditions
-        # self.ao_client.end_session("Success") #TODO: calling end_session here causes "No current session"
-
-    @debug_print_function_params
-    async def on_text(
-        self,
-        text: str,
-        *,
-        run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
     @debug_print_function_params
     async def on_retry(
         self,
@@ -695,15 +866,18 @@ async def on_retry(
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> None:
         action_event = ActionEvent(
-            params={**kwargs},
-            returns=str(retry_state),
-            action_type="retry",
-            # result="Indeterminate" # TODO: currently have no way of recording Indeterminate
+            params={
+                "retry_state": retry_state,
+                "run_id": run_id,
+                "parent_run_id": parent_run_id,
+                "kwargs": kwargs,
+            },
+            action_type="on_retry",
         )
         self.ao_client.record(action_event)
 
     @property
-    async def session_id(self):
-        return self.ao_client.current_session_id
+    def current_session_ids(self):
+        return self.ao_client.current_session_ids
diff --git a/tests/langchain_handlers/_test_langchain_handler.py b/tests/langchain_handlers/_test_langchain_handler.py
index 8f468c54b..53652e7d3 100644
--- a/tests/langchain_handlers/_test_langchain_handler.py
+++ b/tests/langchain_handlers/_test_langchain_handler.py
@@ -1,9 +1,9 @@
 import asyncio
 import os
-from langchain.chat_models import ChatOpenAI
-from langchain.agents import initialize_agent, AgentType
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.agents import tool, AgentExecutor, create_openai_tools_agent
 from dotenv import load_dotenv
-from langchain.agents import tool
 from agentops.partners.langchain_callback_handler import (
     LangchainCallbackHandler as AgentOpsLangchainCallbackHandler,
     AsyncLangchainCallbackHandler as AgentOpsAsyncLangchainCallbackHandler,
 )
@@ -14,58 +14,91 @@
 AGENTOPS_API_KEY = os.environ.get("AGENTOPS_API_KEY")
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
-agentops_handler = AgentOpsLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Langchain Example"])
-llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo")
-
-
-@tool
-def find_movie(genre) -> str:
-    """Find available movies"""
-    # raise ValueError("This is an intentional error for testing.")
-    if genre == "drama":
-        return "Dune 2"
-    else:
-        return "Pineapple Express"
-
-
-tools = [find_movie]
-
-for t in tools:
-    t.callbacks = [agentops_handler]
-
-agent = initialize_agent(
-    tools,
-    llm,
-    agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-    verbose=True,
-    callbacks=[agentops_handler],  # You must pass in a callback handler to record your agent
-    handle_parsing_errors=True,
-)
-
-
-agent.run("What comedies are playing?", callbacks=[agentops_handler])
-
-
-########
-# Async
-
-agentops_handler = AgentOpsAsyncLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Async Example"])
-
-llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo")
-
-agent = initialize_agent(
-    tools,
-    llm,
-    agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-    verbose=True,
-    callbacks=[agentops_handler],  # You must pass in a callback handler to record your agent
-    handle_parsing_errors=True,
-)
-
-
-async def run_async():
-    await agent.run("What comedies are playing?", callbacks=[agentops_handler])
-
-
-asyncio.run(run_async())
+# Sync test
+def run_sync_test():
+    agentops_handler = AgentOpsLangchainCallbackHandler(
+        api_key=AGENTOPS_API_KEY, default_tags=["Langchain", "Sync Handler Test"]
+    )
+
+    llm = ChatOpenAI(
+        openai_api_key=OPENAI_API_KEY,
+        callbacks=[agentops_handler],
+        model="gpt-4o-mini",
+        streaming=False,  # Disable streaming for sync handler
+    )
+
+    @tool
+    def find_movie(genre: str) -> str:
+        """Find available movies"""
+        if genre == "drama":
+            return "Dune 2"
+        else:
+            return "Pineapple Express"
+
+    tools = [find_movie]
+    for t in tools:
+        t.callbacks = [agentops_handler]
+
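+    # create_openai_tools_agent feeds intermediate steps back to the model as chat
+    # messages, so the scratchpad slot must be a MessagesPlaceholder rather than a
+    # string template variable.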
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful assistant. Respond only in Spanish."),
+            ("user", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+
+    agent = create_openai_tools_agent(llm, tools, prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools, callbacks=[agentops_handler])
+
+    return agent_executor.invoke({"input": "What comedies are playing?"})
+
+
+# Async test
+async def run_async_test():
+    agentops_handler = AgentOpsAsyncLangchainCallbackHandler(
+        api_key=AGENTOPS_API_KEY, default_tags=["Langchain", "Async Handler Test"]
+    )
+
+    llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-4o-mini", streaming=True)
+
+    @tool
+    def find_movie(genre: str) -> str:
+        """Find available movies"""
+        if genre == "drama":
+            return "Dune 2"
+        else:
+            return "Pineapple Express"
+
+    tools = [find_movie]
+    for t in tools:
+        t.callbacks = [agentops_handler]
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful assistant. Respond only in Spanish."),
+            ("user", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+
+    agent = create_openai_tools_agent(llm, tools, prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools, callbacks=[agentops_handler])
+
+    return await agent_executor.ainvoke({"input": "What comedies are playing?"})
+
+
+async def main():
+    # Run sync test
+    print("Running sync test...")
+    sync_result = run_sync_test()
+    print(f"Sync test result: {sync_result}\n")
+
+    # Run async test
+    print("Running async test...")
+    async_result = await run_async_test()
+    print(f"Async test result: {async_result}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())