Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 92 additions & 21 deletions python/beeai_framework/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.errors import FrameworkError
from beeai_framework.logger import Logger
from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
from beeai_framework.retryable import (
Retryable,
RetryableConfig,
RetryableContext,
RetryableInput,
)
from beeai_framework.tools.errors import ToolError, ToolInputValidationError
from beeai_framework.tools.events import (
ToolErrorEvent,
Expand All @@ -39,12 +44,34 @@


class Tool(Generic[TInput, TRunOptions, TOutput], ABC):
"""
Base abstraction for all BeeAI tools.

A Tool represents an executable unit that:
- validates structured input using a Pydantic schema
- executes logic via the `_run` method
- optionally caches results
- emits lifecycle events (start, retry, success, error)

Subclasses must implement:
- `name`
- `description`
- `input_schema`
- `_create_emitter`
- `_run`
"""

def __init__(self, options: dict[str, Any] | None = None) -> None:
self._options: dict[str, Any] | None = options or None
self._cache = self.options.get("cache", NullCache[TOutput]()) if self.options else NullCache[TOutput]()
self._cache = (
self.options.get("cache", NullCache[TOutput]())
if self.options
else NullCache[TOutput]()
)
self.middlewares: list[RunMiddlewareType] = []

def __str__(self) -> str:
"""Return the tool name for readable string representation."""
return self.name

@property
Expand Down Expand Up @@ -92,10 +119,14 @@ def _create_emitter(self) -> Emitter:
pass

@abstractmethod
async def _run(self, input: TInput, options: TRunOptions | None, context: RunContext) -> TOutput:
async def _run(
self, input: TInput, options: TRunOptions | None, context: RunContext
) -> TOutput:
pass

def _generate_key(self, input: TInput | dict[str, Any], options: TRunOptions | None = None) -> str:
def _generate_key(
self, input: TInput | dict[str, Any], options: TRunOptions | None = None
) -> str:
options_dict = options.model_dump(exclude_none=True) if options else {}
options_dict.pop("signal", None)
options_dict.pop("retry_options", None)
Expand All @@ -108,9 +139,13 @@ def _validate_input(self, input: TInput | dict[str, Any]) -> TInput:
try:
return self.input_schema.model_validate(input)
except ValidationError as e:
raise ToolInputValidationError("Tool input validation error", cause=e)
raise ToolInputValidationError(
f"Input validation failed for tool '{self.name}'", cause=e
)

def run(self, input: TInput | dict[str, Any], options: TRunOptions | None = None) -> Run[TOutput]:
def run(
self, input: TInput | dict[str, Any], options: TRunOptions | None = None
) -> Run[TOutput]:
async def handler(context: RunContext) -> TOutput:
error_propagated = False

Expand All @@ -120,7 +155,9 @@ async def handler(context: RunContext) -> TOutput:
async def executor(_: RetryableContext) -> TOutput:
nonlocal error_propagated
error_propagated = False
await context.emitter.emit("start", ToolStartEvent(input=validated_input, options=options))
await context.emitter.emit(
"start", ToolStartEvent(input=validated_input, options=options)
)

if self.cache.enabled:
cache_key = self._generate_key(input, options)
Expand All @@ -140,15 +177,23 @@ async def on_error(error: Exception, _: RetryableContext) -> None:
error_propagated = True
err = ToolError.ensure(error)
await context.emitter.emit(
"error", ToolErrorEvent(error=err, input=validated_input, options=options)
"error",
ToolErrorEvent(
error=err, input=validated_input, options=options
),
)
if FrameworkError.is_fatal(err) is True:
raise err

async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
async def on_retry(
ctx: RetryableContext, last_error: Exception
) -> None:
err = ToolError.ensure(last_error)
await context.emitter.emit(
"retry", ToolRetryEvent(error=err, input=validated_input, options=options)
"retry",
ToolRetryEvent(
error=err, input=validated_input, options=options
),
)

output = await Retryable(
Expand All @@ -158,22 +203,33 @@ async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
on_retry=on_retry,
config=RetryableConfig(
max_retries=(
(options.retry_options.max_retries or 0) if options and options.retry_options else 0
(options.retry_options.max_retries or 0)
if options and options.retry_options
else 0
),
factor=(
(options.retry_options.factor or 1)
if options and options.retry_options
else 1
),
factor=((options.retry_options.factor or 1) if options and options.retry_options else 1),
signal=context.signal,
),
)
).get()

await context.emitter.emit(
"success", ToolSuccessEvent(output=output, input=validated_input, options=options)
"success",
ToolSuccessEvent(
output=output, input=validated_input, options=options
),
)
return output
except Exception as e:
err = ToolError.ensure(e, tool=self)
if not error_propagated:
await context.emitter.emit("error", ToolErrorEvent(error=err, input=input, options=options))
await context.emitter.emit(
"error", ToolErrorEvent(error=err, input=input, options=options)
)
raise err
finally:
await context.emitter.emit("finish", None)
Expand All @@ -187,18 +243,24 @@ async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:

async def clone(self) -> Self:
if type(self).clone == Tool.clone:
logging.warning(f"Tool '{self.name}' does not implement the 'clone' method.")
logging.warning(
f"Tool '{self.name}' does not implement the 'clone' method."
)

return self


# this method was inspired by the discussion that was had in this issue:
# https://github.com/pydantic/pydantic/issues/1391
@typing.no_type_check
def get_input_schema(tool_function: Callable, *, name: str | None = None) -> type[BaseModel]:
def get_input_schema(
tool_function: Callable, *, name: str | None = None
) -> type[BaseModel]:
input_model_name = name or tool_function.__name__

args, _, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = inspect.getfullargspec(tool_function)
args, _, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = (
inspect.getfullargspec(tool_function)
)
defaults = defaults or []
args = args or []

Expand All @@ -210,14 +272,21 @@ def get_input_schema(tool_function: Callable, *, name: str | None = None) -> typ
...,
] * non_default_args + defaults

keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs}
params = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults, strict=False)}
keyword_only_params = {
param: kwonlydefaults.get(param, Any) for param in kwonlyargs
}
params = {
param: (annotations.get(param, Any), default)
for param, default in zip(args, defaults, strict=False)
}

input_model = create_model(
input_model_name,
**params,
**keyword_only_params,
__config__=ConfigDict(extra="allow" if varkw else "ignore", arbitrary_types_allowed=True),
__config__=ConfigDict(
extra="allow" if varkw else "ignore", arbitrary_types_allowed=True
),
)

return input_model
Expand Down Expand Up @@ -282,7 +351,9 @@ def _create_emitter(self) -> Emitter:
creator=self,
)

async def _run(self, input: Any, options: ToolRunOptions | None, context: RunContext) -> ToolOutput:
async def _run(
self, input: Any, options: ToolRunOptions | None, context: RunContext
) -> ToolOutput:
tool_input_dict = input.model_dump()
if with_context:
tool_input_dict["context"] = context
Expand Down
Loading