|
28 | 28 | get_before_llm_call_hooks, |
29 | 29 | ) |
30 | 30 | from crewai.utilities.agent_utils import ( |
| 31 | + aget_llm_response, |
31 | 32 | enforce_rpm_limit, |
32 | 33 | format_message_for_llm, |
33 | 34 | get_llm_response, |
|
43 | 44 | from crewai.utilities.constants import TRAINING_DATA_FILE |
44 | 45 | from crewai.utilities.i18n import I18N, get_i18n |
45 | 46 | from crewai.utilities.printer import Printer |
46 | | -from crewai.utilities.tool_utils import execute_tool_and_check_finality |
| 47 | +from crewai.utilities.tool_utils import ( |
| 48 | + aexecute_tool_and_check_finality, |
| 49 | + execute_tool_and_check_finality, |
| 50 | +) |
47 | 51 | from crewai.utilities.training_handler import CrewTrainingHandler |
48 | 52 |
|
49 | 53 |
|
@@ -134,8 +138,8 @@ def __init__( |
134 | 138 | self.messages: list[LLMMessage] = [] |
135 | 139 | self.iterations = 0 |
136 | 140 | self.log_error_after = 3 |
137 | | - self.before_llm_call_hooks: list[Callable] = [] |
138 | | - self.after_llm_call_hooks: list[Callable] = [] |
| 141 | + self.before_llm_call_hooks: list[Callable[..., Any]] = [] |
| 142 | + self.after_llm_call_hooks: list[Callable[..., Any]] = [] |
139 | 143 | self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) |
140 | 144 | self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) |
141 | 145 | if self.llm: |
@@ -312,6 +316,154 @@ def _invoke_loop(self) -> AgentFinish: |
312 | 316 | self._show_logs(formatted_answer) |
313 | 317 | return formatted_answer |
314 | 318 |
|
| 319 | + async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]: |
| 320 | + """Execute the agent asynchronously with given inputs. |
| 321 | +
|
| 322 | + Args: |
| 323 | + inputs: Input dictionary containing prompt variables. |
| 324 | +
|
| 325 | + Returns: |
| 326 | + Dictionary with agent output. |
| 327 | + """ |
| 328 | + if "system" in self.prompt: |
| 329 | + system_prompt = self._format_prompt( |
| 330 | + cast(str, self.prompt.get("system", "")), inputs |
| 331 | + ) |
| 332 | + user_prompt = self._format_prompt( |
| 333 | + cast(str, self.prompt.get("user", "")), inputs |
| 334 | + ) |
| 335 | + self.messages.append(format_message_for_llm(system_prompt, role="system")) |
| 336 | + self.messages.append(format_message_for_llm(user_prompt)) |
| 337 | + else: |
| 338 | + user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs) |
| 339 | + self.messages.append(format_message_for_llm(user_prompt)) |
| 340 | + |
| 341 | + self._show_start_logs() |
| 342 | + |
| 343 | + self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False)) |
| 344 | + |
| 345 | + try: |
| 346 | + formatted_answer = await self._ainvoke_loop() |
| 347 | + except AssertionError: |
| 348 | + self._printer.print( |
| 349 | + content="Agent failed to reach a final answer. This is likely a bug - please report it.", |
| 350 | + color="red", |
| 351 | + ) |
| 352 | + raise |
| 353 | + except Exception as e: |
| 354 | + handle_unknown_error(self._printer, e) |
| 355 | + raise |
| 356 | + |
| 357 | + if self.ask_for_human_input: |
| 358 | + formatted_answer = self._handle_human_feedback(formatted_answer) |
| 359 | + |
| 360 | + self._create_short_term_memory(formatted_answer) |
| 361 | + self._create_long_term_memory(formatted_answer) |
| 362 | + self._create_external_memory(formatted_answer) |
| 363 | + return {"output": formatted_answer.output} |
| 364 | + |
    async def _ainvoke_loop(self) -> AgentFinish:
        """Execute agent loop asynchronously until completion.

        Async mirror of ``_invoke_loop``: repeatedly calls the LLM via
        ``aget_llm_response``, executes any requested tool via
        ``aexecute_tool_and_check_finality``, and loops until the parsed
        response is an ``AgentFinish`` (or the max-iteration guard forces
        one). ``self.iterations`` is incremented in a ``finally`` so every
        pass — including error paths — counts against the budget.

        Returns:
            Final answer from the agent.

        Raises:
            RuntimeError: If the loop exits without producing an
                ``AgentFinish``.
        """
        formatted_answer = None
        while not isinstance(formatted_answer, AgentFinish):
            try:
                # Guard: once the iteration budget is exhausted, ask the
                # helper to produce a forced final answer and stop looping.
                if has_reached_max_iterations(self.iterations, self.max_iter):
                    formatted_answer = handle_max_iterations_exceeded(
                        formatted_answer,
                        printer=self._printer,
                        i18n=self._i18n,
                        messages=self.messages,
                        llm=self.llm,
                        callbacks=self.callbacks,
                    )
                    break

                # Respect the crew's requests-per-minute throttle before
                # each LLM call.
                enforce_rpm_limit(self.request_within_rpm_limit)

                answer = await aget_llm_response(
                    llm=self.llm,
                    messages=self.messages,
                    callbacks=self.callbacks,
                    printer=self._printer,
                    from_task=self.task,
                    from_agent=self.agent,
                    response_model=self.response_model,
                    executor_context=self,
                )
                # Parsed into AgentAction (tool call) or AgentFinish.
                formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]

                if isinstance(formatted_answer, AgentAction):
                    # Propagate the agent's security fingerprint (when
                    # configured) so tool execution can be attributed.
                    fingerprint_context: dict[str, str] = {}
                    if (
                        self.agent
                        and hasattr(self.agent, "security_config")
                        and hasattr(self.agent.security_config, "fingerprint")
                    ):
                        fingerprint_context = {
                            "agent_fingerprint": str(
                                self.agent.security_config.fingerprint
                            )
                        }

                    tool_result = await aexecute_tool_and_check_finality(
                        agent_action=formatted_answer,
                        fingerprint_context=fingerprint_context,
                        tools=self.tools,
                        i18n=self._i18n,
                        agent_key=self.agent.key if self.agent else None,
                        agent_role=self.agent.role if self.agent else None,
                        tools_handler=self.tools_handler,
                        task=self.task,
                        agent=self.agent,
                        function_calling_llm=self.function_calling_llm,
                        crew=self.crew,
                    )
                    # May upgrade the action to an AgentFinish if the tool
                    # result is final.
                    formatted_answer = self._handle_agent_action(
                        formatted_answer, tool_result
                    )

                self._invoke_step_callback(formatted_answer)  # type: ignore[arg-type]
                self._append_message(formatted_answer.text)  # type: ignore[union-attr,attr-defined]

            except OutputParserError as e:
                # Unparseable LLM output: the helper appends corrective
                # guidance to the conversation so the next pass can retry.
                formatted_answer = handle_output_parser_exception(  # type: ignore[assignment]
                    e=e,
                    messages=self.messages,
                    iterations=self.iterations,
                    log_error_after=self.log_error_after,
                    printer=self._printer,
                )

            except Exception as e:
                # litellm errors are re-raised untouched — presumably so
                # upstream handlers can match on their types (TODO confirm).
                if e.__class__.__module__.startswith("litellm"):
                    raise e
                if is_context_length_exceeded(e):
                    # Shrink/summarize the conversation, then retry the
                    # loop iteration.
                    handle_context_length(
                        respect_context_window=self.respect_context_window,
                        printer=self._printer,
                        messages=self.messages,
                        llm=self.llm,
                        callbacks=self.callbacks,
                        i18n=self._i18n,
                    )
                    continue
                handle_unknown_error(self._printer, e)
                raise e
            finally:
                # Count every pass, including error and retry paths.
                self.iterations += 1

        # Defensive check: the max-iterations `break` path could leave a
        # non-AgentFinish value here.
        if not isinstance(formatted_answer, AgentFinish):
            raise RuntimeError(
                "Agent execution ended without reaching a final answer. "
                f"Got {type(formatted_answer).__name__} instead of AgentFinish."
            )
        self._show_logs(formatted_answer)
        return formatted_answer
| 466 | + |
315 | 467 | def _handle_agent_action( |
316 | 468 | self, formatted_answer: AgentAction, tool_result: ToolResult |
317 | 469 | ) -> AgentAction | AgentFinish: |
|
0 commit comments