Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from google.genai import types
import pydantic
from enum import Enum

from ..utils.variant_utils import GoogleLLMVariant

Expand Down Expand Up @@ -75,7 +76,7 @@ def _raise_if_schema_unsupported(
):
if variant == GoogleLLMVariant.GEMINI_API:
_raise_for_any_of_if_mldev(schema)
_update_for_default_if_mldev(schema)
# _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value


def _is_default_value_compatible(
Expand Down Expand Up @@ -145,6 +146,16 @@ def _parse_schema_from_parameter(
schema.type = _py_builtin_type_to_schema_type[param.annotation]
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
schema.type = types.Type.STRING
schema.enum = [e.value for e in param.annotation]
if param.default is not inspect.Parameter.empty:
default_value = param.default.value if isinstance(param.default, Enum) else param.default
if default_value not in schema.enum:
raise ValueError(default_value_error_msg)
schema.default = default_value
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
# only parse simple UnionType, example int | str | float | bool
Expand Down
29 changes: 28 additions & 1 deletion tests/unittests/tools/test_build_function_declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
# TODO: crewai requires python 3.10 as minimum
# from crewai_tools import FileReadTool
from pydantic import BaseModel

from enum import Enum
import pytest

def test_string_input():
def simple_function(input_str: str) -> str:
Expand Down Expand Up @@ -219,6 +220,32 @@ def simple_function(
assert function_decl.parameters.properties['input_dir'].type == 'ARRAY'
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'

def test_enums():

class InputEnum(Enum):
AGENT = "agent"
TOOL = "tool"

def simple_function(input:InputEnum=InputEnum.AGENT):
return input.value

function_decl = _automatic_function_calling_util.build_function_declaration(
func=simple_function
)

assert function_decl.name == 'simple_function'
assert function_decl.parameters.type == 'OBJECT'
assert function_decl.parameters.properties['input'].type == 'STRING'
assert function_decl.parameters.properties['input'].default == 'agent'
assert function_decl.parameters.properties['input'].enum == ['agent', 'tool']

def simple_function_with_wrong_enum(input:InputEnum="WRONG_ENUM"):
return input.value

with pytest.raises(ValueError):
_automatic_function_calling_util.build_function_declaration(
func=simple_function_with_wrong_enum
)

def test_basemodel_list():
class ChildInput(BaseModel):
Expand Down