|
| 1 | +import sys |
| 2 | +import types |
| 3 | +import pytest |
| 4 | +import json |
| 5 | +import inspect |
| 6 | +from typing import Annotated, List, Dict |
| 7 | +from unittest.mock import patch |
| 8 | + |
| 9 | +# ----- Mocking semantic_kernel.functions.kernel_function ----- |
| 10 | +semantic_kernel = types.ModuleType("semantic_kernel") |
| 11 | +semantic_kernel.functions = types.ModuleType("functions") |
| 12 | + |
| 13 | +def mock_kernel_function(*args, **kwargs): |
| 14 | + def decorator(func): |
| 15 | + func.__kernel_function__ = types.SimpleNamespace(**kwargs) |
| 16 | + return func |
| 17 | + return decorator |
| 18 | + |
| 19 | +semantic_kernel.functions.kernel_function = mock_kernel_function |
| 20 | +sys.modules["semantic_kernel"] = semantic_kernel |
| 21 | +sys.modules["semantic_kernel.functions"] = semantic_kernel.functions |
| 22 | +# ------------------------------------------------------------- |
| 23 | + |
| 24 | +# ----- Mocking models.messages_kernel.AgentType ----- |
| 25 | +mock_models = types.ModuleType("models") |
| 26 | +mock_messages_kernel = types.ModuleType("models.messages_kernel") |
| 27 | + |
| 28 | +class AgentType: |
| 29 | + GENERIC = type("AgentValue", (), {"value": "generic"}) |
| 30 | + |
| 31 | +mock_messages_kernel.AgentType = AgentType |
| 32 | +mock_models.messages_kernel = mock_messages_kernel |
| 33 | + |
| 34 | +sys.modules["models"] = mock_models |
| 35 | +sys.modules["models.messages_kernel"] = mock_messages_kernel |
| 36 | +# ---------------------------------------------------- |
| 37 | + |
| 38 | +from src.backend.kernel_tools.generic_tools import GenericTools |
| 39 | +from semantic_kernel.functions import kernel_function |
| 40 | + |
| 41 | +# ----------------- Inject kernel_function examples ----------------- |
| 42 | + |
| 43 | +@kernel_function(description="Add two integers") |
| 44 | +async def add_numbers(a: int, b: int) -> int: |
| 45 | + """Adds two numbers""" |
| 46 | + return a + b |
| 47 | + |
| 48 | +GenericTools.add_numbers = staticmethod(add_numbers) |
| 49 | + |
| 50 | +@kernel_function(description="Add two integers") |
| 51 | +async def add(x: int, y: int) -> int: |
| 52 | + return x + y |
| 53 | + |
| 54 | +@kernel_function(description="Subtract two integers") |
| 55 | +async def subtract(x: int, y: int) -> int: |
| 56 | + return x - y |
| 57 | + |
| 58 | +@kernel_function |
| 59 | +async def only_docstring(x: int) -> int: |
| 60 | + """Docstring exists""" |
| 61 | + return x |
| 62 | + |
| 63 | +@kernel_function(description="Has cls parameter") |
| 64 | +async def func_with_cls(cls, param: int) -> int: |
| 65 | + return param |
| 66 | + |
| 67 | +@kernel_function(description="Sample") |
| 68 | +async def sample(x: int) -> int: |
| 69 | + return x |
| 70 | + |
| 71 | +@kernel_function(description="Annotated param") |
| 72 | +async def annotated_param(x: Annotated[int, "Some metadata"]) -> int: |
| 73 | + return x |
| 74 | + |
| 75 | +# ------------------------- Tests ------------------------- |
| 76 | + |
| 77 | +def test_get_all_kernel_functions_includes_add_numbers(): |
| 78 | + functions = GenericTools.get_all_kernel_functions() |
| 79 | + assert "add_numbers" in functions |
| 80 | + assert inspect.iscoroutinefunction(functions["add_numbers"]) |
| 81 | + |
| 82 | +def test_generate_tools_json_doc_includes_add_numbers_arguments(): |
| 83 | + json_doc = GenericTools.generate_tools_json_doc() |
| 84 | + parsed = json.loads(json_doc) |
| 85 | + found = False |
| 86 | + for tool in parsed: |
| 87 | + if tool["function"] == "add_numbers": |
| 88 | + found = True |
| 89 | + args = json.loads(tool["arguments"].replace("'", '"')) |
| 90 | + assert "a" in args |
| 91 | + assert args["a"]["type"] == "int" |
| 92 | + assert args["a"]["title"] == "A" |
| 93 | + assert args["a"]["description"] == "a" |
| 94 | + assert "b" in args |
| 95 | + assert args["b"]["type"] == "int" |
| 96 | + assert found |
| 97 | + |
| 98 | +def test_generate_tools_json_doc_handles_non_kernel_function(): |
| 99 | + class Dummy(GenericTools): |
| 100 | + @staticmethod |
| 101 | + def regular_function(): |
| 102 | + pass |
| 103 | + Dummy.agent_name = "dummy" |
| 104 | + json_doc = Dummy.generate_tools_json_doc() |
| 105 | + parsed = json.loads(json_doc) |
| 106 | + assert all(tool["function"] != "regular_function" for tool in parsed) |
| 107 | + |
| 108 | +def test_get_all_kernel_functions_no_kernel_functions(): |
| 109 | + class Dummy(GenericTools): |
| 110 | + pass |
| 111 | + functions = Dummy.get_all_kernel_functions() |
| 112 | + own_functions = {name: fn for name, fn in functions.items() if name in Dummy.__dict__} |
| 113 | + assert own_functions == {} |
| 114 | + |
| 115 | +def test_get_all_kernel_functions_multiple_kernel_functions(): |
| 116 | + class Dummy(GenericTools): |
| 117 | + add = staticmethod(add) |
| 118 | + subtract = staticmethod(subtract) |
| 119 | + dummy = Dummy() |
| 120 | + funcs = dummy.get_all_kernel_functions() |
| 121 | + assert "add" in funcs |
| 122 | + assert "subtract" in funcs |
| 123 | + |
| 124 | +def test_generate_tools_json_doc_no_arguments(): |
| 125 | + @kernel_function(description="Return a constant string") |
| 126 | + async def return_constant() -> str: |
| 127 | + return "Constant" |
| 128 | + GenericTools.return_constant = staticmethod(return_constant) |
| 129 | + json_doc = GenericTools.generate_tools_json_doc() |
| 130 | + parsed = json.loads(json_doc) |
| 131 | + tool = next((t for t in parsed if t["function"] == "return_constant"), None) |
| 132 | + assert tool is not None |
| 133 | + assert json.loads(tool["arguments"].replace("'", '"')) == {} |
| 134 | + |
| 135 | +def test_generate_tools_json_doc_complex_argument_types(): |
| 136 | + @kernel_function(description="Process a list of integers") |
| 137 | + async def process_list(numbers: List[int]) -> int: |
| 138 | + return sum(numbers) |
| 139 | + @kernel_function(description="Process a dictionary") |
| 140 | + async def process_dict(data: Dict[str, int]) -> int: |
| 141 | + return sum(data.values()) |
| 142 | + GenericTools.process_list = staticmethod(process_list) |
| 143 | + GenericTools.process_dict = staticmethod(process_dict) |
| 144 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 145 | + tool1 = next((t for t in parsed if t["function"] == "process_list"), None) |
| 146 | + assert tool1 is not None |
| 147 | + assert json.loads(tool1["arguments"].replace("'", '"'))["numbers"]["type"] == "list" |
| 148 | + tool2 = next((t for t in parsed if t["function"] == "process_dict"), None) |
| 149 | + assert tool2 is not None |
| 150 | + assert json.loads(tool2["arguments"].replace("'", '"'))["data"]["type"] == "dict" |
| 151 | + |
| 152 | +def test_generate_tools_json_doc_boolean_argument_type(): |
| 153 | + @kernel_function(description="Toggle a feature") |
| 154 | + async def toggle_feature(enabled: bool) -> str: |
| 155 | + return "Enabled" if enabled else "Disabled" |
| 156 | + GenericTools.toggle_feature = staticmethod(toggle_feature) |
| 157 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 158 | + tool = next((t for t in parsed if t["function"] == "toggle_feature"), None) |
| 159 | + assert tool is not None |
| 160 | + assert json.loads(tool["arguments"].replace("'", '"'))["enabled"]["type"] == "bool" |
| 161 | + |
| 162 | +def test_generate_tools_json_doc_float_argument_type(): |
| 163 | + @kernel_function(description="Multiply a number") |
| 164 | + async def multiply_by_two(value: float) -> float: |
| 165 | + return value * 2 |
| 166 | + GenericTools.multiply_by_two = staticmethod(multiply_by_two) |
| 167 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 168 | + tool = next((t for t in parsed if t["function"] == "multiply_by_two"), None) |
| 169 | + assert tool is not None |
| 170 | + assert json.loads(tool["arguments"].replace("'", '"'))["value"]["type"] == "float" |
| 171 | + |
| 172 | + |
| 173 | + |
| 174 | +def test_generate_tools_json_doc_raw_list_type(): |
| 175 | + @kernel_function(description="Accept raw list type") |
| 176 | + async def accept_raw_list(items: list) -> int: |
| 177 | + return len(items) |
| 178 | + GenericTools.accept_raw_list = staticmethod(accept_raw_list) |
| 179 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 180 | + tool = next((t for t in parsed if t["function"] == "accept_raw_list"), None) |
| 181 | + assert tool is not None |
| 182 | + assert json.loads(tool["arguments"].replace("'", '"'))["items"]["type"] == "list" |
| 183 | + |
| 184 | +def test_generate_tools_json_doc_raw_dict_type(): |
| 185 | + @kernel_function(description="Accept raw dict type") |
| 186 | + async def accept_raw_dict(data: dict) -> int: |
| 187 | + return len(data) |
| 188 | + GenericTools.accept_raw_dict = staticmethod(accept_raw_dict) |
| 189 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 190 | + tool = next((t for t in parsed if t["function"] == "accept_raw_dict"), None) |
| 191 | + assert tool is not None |
| 192 | + assert json.loads(tool["arguments"].replace("'", '"'))["data"]["type"] == "dict" |
| 193 | + |
| 194 | +def test_generate_tools_json_doc_fallback_type(): |
| 195 | + class CustomType: |
| 196 | + pass |
| 197 | + @kernel_function(description="Uses custom type") |
| 198 | + async def use_custom_type(param: CustomType) -> str: |
| 199 | + return "ok" |
| 200 | + GenericTools.use_custom_type = staticmethod(use_custom_type) |
| 201 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 202 | + tool = next((t for t in parsed if t["function"] == "use_custom_type"), None) |
| 203 | + assert tool is not None |
| 204 | + assert json.loads(tool["arguments"].replace("'", '"'))["param"]["type"] == "customtype" |
| 205 | + |
| 206 | + |
| 207 | + |
| 208 | +def test_generate_tools_json_doc_skips_cls_param(): |
| 209 | + GenericTools.func_with_cls = staticmethod(func_with_cls) |
| 210 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 211 | + tool = next((t for t in parsed if t["function"] == "func_with_cls"), None) |
| 212 | + assert tool is not None |
| 213 | + args = json.loads(tool["arguments"].replace("'", '"')) |
| 214 | + assert "cls" not in args |
| 215 | + assert "param" in args |
| 216 | + |
| 217 | +def test_generate_tools_json_doc_with_no_kernel_functions(): |
| 218 | + class Dummy: |
| 219 | + agent_name = "dummy" |
| 220 | + @classmethod |
| 221 | + def get_all_kernel_functions(cls): |
| 222 | + return [] |
| 223 | + @classmethod |
| 224 | + def generate_tools_json_doc(cls): |
| 225 | + return json.dumps([]) |
| 226 | + parsed = json.loads(Dummy.generate_tools_json_doc()) |
| 227 | + assert parsed == [] |
| 228 | + |
| 229 | +def test_generate_tools_json_doc_sets_agent_name(): |
| 230 | + class CustomAgent(GenericTools): |
| 231 | + agent_name = "custom_agent" |
| 232 | + sample = staticmethod(sample) |
| 233 | + parsed = json.loads(CustomAgent.generate_tools_json_doc()) |
| 234 | + tool = next((t for t in parsed if t["function"] == "sample"), None) |
| 235 | + assert tool is not None |
| 236 | + assert tool["agent"] == "custom_agent" |
| 237 | + |
| 238 | +def test_generate_tools_json_doc_handles_annotated_type(): |
| 239 | + GenericTools.annotated_param = staticmethod(annotated_param) |
| 240 | + parsed = json.loads(GenericTools.generate_tools_json_doc()) |
| 241 | + tool = next((t for t in parsed if t["function"] == "annotated_param"), None) |
| 242 | + assert tool is not None |
| 243 | + args = json.loads(tool["arguments"].replace("'", '"')) |
| 244 | + assert args["x"]["type"] == "int" |
| 245 | + |
| 246 | +def test_generate_tools_json_doc_multiple_functions(): |
| 247 | + class Dummy(GenericTools): |
| 248 | + agent_name = "dummy" |
| 249 | + @kernel_function(description="Add numbers") |
| 250 | + async def add(self, a: int, b: int) -> int: |
| 251 | + return a + b |
| 252 | + @kernel_function(description="Concat strings") |
| 253 | + async def concat(self, x: str, y: str) -> str: |
| 254 | + return x + y |
| 255 | + parsed = json.loads(Dummy.generate_tools_json_doc()) |
| 256 | + assert any(tool["function"] == "add" for tool in parsed) |
| 257 | + assert any(tool["function"] == "concat" for tool in parsed) |
| 258 | + assert all(tool["agent"] == "dummy" for tool in parsed) |
0 commit comments