diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index 53c43acd56847..ff4a63b792f7e 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -1,8 +1,7 @@ """Convert functions and runnables to tools.""" import inspect -from collections.abc import Callable -from typing import Any, Literal, get_type_hints, overload +from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload from pydantic import BaseModel, Field, create_model @@ -16,14 +15,14 @@ @overload def tool( *, - description: str | None = None, + description: Optional[str] = None, return_direct: bool = False, - args_schema: ArgsSchema | None = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Callable | Runnable], BaseTool]: ... +) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... @overload @@ -31,9 +30,9 @@ def tool( name_or_callable: str, runnable: Runnable, *, - description: str | None = None, + description: Optional[str] = None, return_direct: bool = False, - args_schema: ArgsSchema | None = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, @@ -45,9 +44,9 @@ def tool( def tool( name_or_callable: Callable, *, - description: str | None = None, + description: Optional[str] = None, return_direct: bool = False, - args_schema: ArgsSchema | None = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, @@ -59,28 +58,31 @@ def tool( def tool( name_or_callable: str, *, - description: str | None = None, + description: Optional[str] = None, return_direct: bool = False, - args_schema: ArgsSchema | None = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Callable | Runnable], BaseTool]: ... +) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... def tool( - name_or_callable: str | Callable | None = None, - runnable: Runnable | None = None, + name_or_callable: Optional[Union[str, Callable]] = None, + runnable: Optional[Runnable] = None, *args: Any, - description: str | None = None, + description: Optional[str] = None, return_direct: bool = False, - args_schema: ArgsSchema | None = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> BaseTool | Callable[[Callable | Runnable], BaseTool]: +) -> Union[ + BaseTool, + Callable[[Union[Callable, Runnable]], BaseTool], +]: """Make tools out of functions, can be used with or without arguments. Args: @@ -155,7 +157,7 @@ def search_api(query: str) -> str: def search_api(query: str) -> tuple[str, dict]: return "partial json of results", {"full": "object of results"} - !!! version-added "Added in version 0.2.14" + .. versionadded:: 0.2.14 Parse Google-style docstrings: @@ -229,7 +231,7 @@ def invalid_docstring_3(bar: str, baz: int) -> str: def _create_tool_factory( tool_name: str, - ) -> Callable[[Callable | Runnable], BaseTool]: + ) -> Callable[[Union[Callable, Runnable]], BaseTool]: """Create a decorator that takes a callable and returns a tool. Args: @@ -239,7 +241,7 @@ def _create_tool_factory( A function that takes a callable or Runnable and returns a tool. """ - def _tool_factory(dec_func: Callable | Runnable) -> BaseTool: + def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool: tool_description = description if isinstance(dec_func, Runnable): runnable = dec_func @@ -249,18 +251,18 @@ def _tool_factory(dec_func: Callable | Runnable) -> BaseTool: raise ValueError(msg) async def ainvoke_wrapper( - callbacks: Callbacks | None = None, **kwargs: Any + callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) def invoke_wrapper( - callbacks: Callbacks | None = None, **kwargs: Any + callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: return runnable.invoke(kwargs, {"callbacks": callbacks}) coroutine = ainvoke_wrapper func = invoke_wrapper - schema: ArgsSchema | None = runnable.input_schema + schema: Optional[ArgsSchema] = runnable.input_schema tool_description = description or repr(runnable) elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func @@ -350,7 +352,7 @@ def invoke_wrapper( # @tool(parse_docstring=True) # def my_tool(): # pass - def _partial(func: Callable | Runnable) -> BaseTool: + def _partial(func: Union[Callable, Runnable]) -> BaseTool: """Partial function that takes a callable and returns a tool.""" name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ tool_factory = _create_tool_factory(name_) @@ -368,7 +370,7 @@ def _get_description_from_runnable(runnable: Runnable) -> str: def _get_schema_from_runnable_and_arg_types( runnable: Runnable, name: str, - arg_types: dict[str, type] | None = None, + arg_types: Optional[dict[str, type]] = None, ) -> type[BaseModel]: """Infer args_schema for tool.""" if arg_types is None: @@ -387,11 +389,11 @@ def _get_schema_from_runnable_and_arg_types( def convert_runnable_to_tool( runnable: Runnable, - args_schema: type[BaseModel] | None = None, + args_schema: Optional[type[BaseModel]] = None, *, - name: str | None = None, - description: str | None = None, - arg_types: dict[str, type] | None = None, + name: Optional[str] = None, + description: Optional[str] = None, + arg_types: Optional[dict[str, type]] = None, ) -> BaseTool: """Convert a Runnable into a BaseTool. @@ -419,10 +421,12 @@ def convert_runnable_to_tool( description=description, ) - async def ainvoke_wrapper(callbacks: Callbacks | None = None, **kwargs: Any) -> Any: + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: return await runnable.ainvoke(kwargs, config={"callbacks": callbacks}) - def invoke_wrapper(callbacks: Callbacks | None = None, **kwargs: Any) -> Any: + def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any: return runnable.invoke(kwargs, config={"callbacks": callbacks}) if ( @@ -442,4 +446,4 @@ def invoke_wrapper(callbacks: Callbacks | None = None, **kwargs: Any) -> Any: coroutine=ainvoke_wrapper, description=description, args_schema=args_schema, - ) + ) \ No newline at end of file diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index bf98a980eee8c..6becaba9b99e4 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -97,6 +97,55 @@ def search_api(query: str) -> str: assert search_api.invoke("test") == "API result" +def test_tool_decorator_type_annotations() -> None: + """Test that tool decorator has proper type annotations for strict type checking.""" + # This test verifies the fix for GitHub issue #33019 + # Previously, importing tool in strict Pylance mode would show: + # "Type of 'tool' is partially unknown" + + # Test basic decorator usage + @tool + def basic_tool(query: str) -> str: + """Basic tool.""" + return query + + # Test decorator with custom name + @tool("custom_name") + def named_tool(x: int) -> str: + """Named tool.""" + return str(x) + + # Test decorator with keyword arguments + @tool(description="Custom description") + def described_tool(value: float) -> str: + """Described tool.""" + return str(value) + + # Test direct function conversion + def raw_func(data: str) -> str: + """Raw function.""" + return data + + direct_tool = tool(raw_func) + + # Verify all tools are created correctly + assert isinstance(basic_tool, BaseTool) + assert isinstance(named_tool, BaseTool) + assert isinstance(described_tool, BaseTool) + assert isinstance(direct_tool, BaseTool) + + assert basic_tool.name == "basic_tool" + assert named_tool.name == "custom_name" + assert described_tool.name == "described_tool" + assert direct_tool.name == "raw_func" + + # Test descriptions + assert basic_tool.description == "Basic tool." + assert named_tool.description == "Named tool." + assert described_tool.description == "Custom description" + assert direct_tool.description == "Raw function." + + class _MockSchema(BaseModel): """Return the arguments directly."""