-
-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathtest_tool_manager.py
More file actions
820 lines (650 loc) · 25.1 KB
/
test_tool_manager.py
File metadata and controls
820 lines (650 loc) · 25.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
from unittest.mock import Mock
import pytest
from mesa_llm.tools.tool_decorator import _GLOBAL_TOOL_REGISTRY, tool
from mesa_llm.tools.tool_manager import ToolManager
class TestToolManager:
def setup_method(self):
"""Set up test fixtures before each test method."""
# Clear global registry to start fresh
_GLOBAL_TOOL_REGISTRY.clear()
# Clear instances list
ToolManager.instances.clear()
def teardown_method(self):
"""Clean up after each test method."""
_GLOBAL_TOOL_REGISTRY.clear()
ToolManager.instances.clear()
def test_init_empty(self):
"""Test initialization with no tools."""
manager = ToolManager()
assert isinstance(manager.tools, dict)
assert len(manager.tools) == 0
assert manager in ToolManager.instances
def test_init_with_global_tools(self):
"""Test initialization with global tools."""
# Register a tool globally first
@tool
def test_global_tool(agent, param1: str) -> str:
"""Test global tool.
Args:
agent: The agent making the request (provided automatically)
param1: A test parameter.
Returns:
The input parameter.
"""
return param1
manager = ToolManager()
assert "test_global_tool" in manager.tools
assert manager.tools["test_global_tool"] == test_global_tool
def test_init_with_extra_tools(self):
"""Test initialization with extra tools."""
def extra_tool(agent, x: int) -> int:
"""Extra tool.
Args:
agent: The agent making the request (provided automatically)
x: Input number.
Returns:
The input number.
"""
return x
extra_tool.__tool_schema__ = {
"type": "function",
"function": {
"name": "extra_tool",
"description": "Extra tool returns: The input number.",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "Input number."}
},
"required": ["x"],
},
},
}
extra_tools = {"extra_tool": extra_tool}
manager = ToolManager(extra_tools=extra_tools)
assert "extra_tool" in manager.tools
def test_register_tool(self):
"""Test registering a tool manually."""
manager = ToolManager()
def manual_tool(agent, text: str) -> str:
"""Manual tool.
Args:
agent: The agent making the request (provided automatically)
text: Input text.
Returns:
The input text.
"""
return text
manager.register(manual_tool)
assert "manual_tool" in manager.tools
assert manager.tools["manual_tool"] == manual_tool
def test_add_tool_to_all(self):
"""Test adding a tool to all manager instances."""
manager1 = ToolManager()
manager2 = ToolManager()
def shared_tool(agent, value: str) -> str:
"""Shared tool.
Args:
agent: The agent making the request (provided automatically)
value: Input value.
Returns:
The input value.
"""
return value
ToolManager.add_tool_to_all(shared_tool)
assert "shared_tool" in manager1.tools
assert "shared_tool" in manager2.tools
def test_get_tool_schema(self):
"""Test getting tool schema."""
manager = ToolManager()
@tool
def schema_test_tool(agent, param: str) -> str:
"""Schema test tool.
Args:
agent: The agent making the request (provided automatically)
param: A parameter.
Returns:
The parameter.
"""
return param
manager.register(schema_test_tool)
schema = manager.get_tool_schema(schema_test_tool, "schema_test_tool")
assert "type" in schema
assert "function" in schema
assert schema["function"]["name"] == "schema_test_tool"
def test_get_tool_schema_missing(self):
"""Test getting schema for tool without schema."""
manager = ToolManager()
def no_schema_tool():
return "test"
schema = manager.get_tool_schema(no_schema_tool, "no_schema_tool")
assert "error" in schema
def test_get_all_tools_schema(self):
"""Test getting all tools schemas."""
@tool
def tool1(agent, x: int) -> int:
"""Tool 1.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
@tool
def tool2(agent, y: str) -> str:
"""Tool 2.
Args:
agent: The agent making the request (provided automatically)
y: Input.
Returns:
Output.
"""
return y
manager = ToolManager()
schemas = manager.get_all_tools_schema()
assert len(schemas) == 2
assert all("function" in schema for schema in schemas)
def test_get_all_tools_schema_with_selected_tools(self):
"""Test getting schemas for selected tools only."""
@tool
def tool_a(agent, x: int) -> int:
"""Tool A.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
@tool
def tool_b(agent, y: str) -> str:
"""Tool B.
Args:
agent: The agent making the request (provided automatically)
y: Input.
Returns:
Output.
"""
return y
@tool
def tool_c(agent, z: float) -> float:
"""Tool C.
Args:
agent: The agent making the request (provided automatically)
z: Input.
Returns:
Output.
"""
return z
manager = ToolManager()
# Test selecting specific tools
selected_tools = ["tool_a", "tool_c"]
schemas = manager.get_all_tools_schema(selected_tools)
assert len(schemas) == 2
tool_names = [schema["function"]["name"] for schema in schemas]
assert "tool_a" in tool_names
assert "tool_c" in tool_names
assert "tool_b" not in tool_names
def test_get_all_tools_schema_empty_list(self):
"""Test that empty list returns all tools (current behavior)."""
@tool
def test_tool(agent, x: int) -> int:
"""Test tool.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager()
# Empty list should return all tools (current behavior)
all_schemas = manager.get_all_tools_schema()
empty_list_schemas = manager.get_all_tools_schema([])
assert len(empty_list_schemas) == len(all_schemas)
def test_get_all_tools_schema_none(self):
"""Test that None returns all tools."""
@tool
def test_tool(agent, x: int) -> int:
"""Test tool.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager()
all_schemas = manager.get_all_tools_schema()
none_schemas = manager.get_all_tools_schema(None)
assert len(none_schemas) == len(all_schemas)
def test_get_all_tools_schema_nonexistent_tools(self):
"""Test that requesting nonexistent tools raises appropriate errors."""
@tool
def existing_tool(agent, x: int) -> int:
"""Existing tool.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager()
# Test with nonexistent tools
selected_tools = ["existing_tool", "nonexistent_tool"]
with pytest.raises(KeyError):
manager.get_all_tools_schema(selected_tools)
def test_get_all_tools_schema_single_tool(self):
"""Test selecting a single tool."""
@tool
def single_tool(agent, x: int) -> int:
"""Single tool.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
@tool
def other_tool(agent, y: str) -> str:
"""Other tool.
Args:
agent: The agent making the request (provided automatically)
y: Input.
Returns:
Output.
"""
return y
manager = ToolManager()
schemas = manager.get_all_tools_schema(["single_tool"])
assert len(schemas) == 1
assert schemas[0]["function"]["name"] == "single_tool"
def test_call_tool_success(self):
"""Test successfully calling a tool."""
manager = ToolManager()
def callable_tool(agent, message: str) -> str:
"""Callable tool.
Args:
agent: The agent making the request (provided automatically)
message: Input message.
Returns:
The message with prefix.
"""
return f"Result: {message}"
manager.register(callable_tool)
result = manager.call("callable_tool", {"agent": Mock(), "message": "test"})
assert result == "Result: test"
def test_call_tool_not_found(self):
"""Test calling a non-existent tool."""
manager = ToolManager()
with pytest.raises(ValueError, match="Tool 'nonexistent' not found"):
manager.call("nonexistent", {})
def test_has_tool(self):
"""Test checking if a tool exists."""
manager = ToolManager()
def existing_tool():
return "test"
manager.register(existing_tool)
assert manager.has_tool("existing_tool") is True
assert manager.has_tool("nonexistent_tool") is False
def test_call_tools_no_tool_calls(self):
"""Test call_tools with response that has no tool_calls."""
manager = ToolManager()
mock_agent = Mock()
mock_response = Mock()
mock_response.tool_calls = None
result = manager.call_tools(mock_agent, mock_response)
assert result == []
def test_call_tools_empty_tool_calls(self):
"""Test call_tools with empty tool_calls list."""
manager = ToolManager()
mock_agent = Mock()
mock_response = Mock()
mock_response.tool_calls = []
result = manager.call_tools(mock_agent, mock_response)
assert result == []
def test_call_tools_success(self):
"""Test successful tool calling."""
manager = ToolManager()
@tool
def test_tool(agent, param1: str) -> str:
"""Test tool for call_tools.
Args:
agent: The agent making the request (provided automatically)
param1: Test parameter.
Returns:
Processed parameter.
"""
return f"Processed: {param1}"
# Mock agent
mock_agent = Mock()
# Mock LLM response with tool calls
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = '{"param1": "test_value"}'
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = manager.call_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_123"
assert result[0]["role"] == "tool"
assert result[0]["name"] == "test_tool"
assert "Processed: test_value" in result[0]["response"]
def test_call_tools_function_not_found(self):
"""Test call_tools with non-existent function."""
manager = ToolManager()
mock_agent = Mock()
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.function.name = "nonexistent_function"
mock_tool_call.function.arguments = '{"param": "value"}'
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = manager.call_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_123"
assert "Error:" in result[0]["response"]
def test_call_tools_invalid_json(self):
"""Test call_tools with invalid JSON arguments."""
manager = ToolManager()
@tool
def test_tool(agent, param1: str) -> str:
"""Test tool.
Args:
agent: The agent making the request (provided automatically)
param1: Test parameter.
Returns:
Processed parameter.
"""
return f"Processed: {param1}"
mock_agent = Mock()
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = '{"param1": invalid_json}'
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = manager.call_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_123"
assert "Error:" in result[0]["response"]
def test_call_tools_successful_argument_filtering(self):
"""Test call_tools with argument filtering when function signature doesn't match."""
manager = ToolManager()
def simple_tool(required_param: str) -> str:
"""Simple tool that only takes required_param.
Args:
required_param: The only parameter this function accepts.
Returns:
Processed parameter.
"""
return f"Simple: {required_param}"
# Register tool manually without using decorator to test filtering
manager.register(simple_tool)
mock_agent = Mock()
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.function.name = "simple_tool"
# Include extra parameters that the function doesn't accept
mock_tool_call.function.arguments = (
'{"required_param": "test", "extra_param": "ignored"}'
)
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = manager.call_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_123"
assert "Simple: test" in result[0]["response"]
def test_call_tools_no_response(self):
"""Test call_tools when tool returns None."""
manager = ToolManager()
@tool
def silent_tool(agent) -> None:
"""Tool that returns None.
Args:
agent: The agent making the request (provided automatically)
Returns:
None
"""
return None
mock_agent = Mock()
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.function.name = "silent_tool"
mock_tool_call.function.arguments = "{}"
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = manager.call_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_123"
assert "silent_tool executed successfully" in result[0]["response"]
def test_call_tools_general_exception(self):
"""Test call_tools handling of general exceptions."""
manager = ToolManager()
mock_agent = Mock()
# Create a mock response that will cause an AttributeError
mock_response = Mock()
# Remove the tool_calls attribute to cause AttributeError
del mock_response.tool_calls
result = manager.call_tools(mock_agent, mock_response)
assert result == []
def test_selected_tools_consistency(self):
"""Test that selected_tools parameter works consistently."""
@tool
def consistency_tool_a(agent, x: int) -> int:
"""Consistency tool A.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
@tool
def consistency_tool_b(agent, y: str) -> str:
"""Consistency tool B.
Args:
agent: The agent making the request (provided automatically)
y: Input.
Returns:
Output.
"""
return y
@tool
def consistency_tool_c(agent, z: float) -> float:
"""Consistency tool C.
Args:
agent: The agent making the request (provided automatically)
z: Input.
Returns:
Output.
"""
return z
manager = ToolManager()
# Test that same selected_tools always returns same schemas
selected_tools = ["consistency_tool_a", "consistency_tool_c"]
schemas1 = manager.get_all_tools_schema(selected_tools)
schemas2 = manager.get_all_tools_schema(selected_tools)
names1 = sorted([schema["function"]["name"] for schema in schemas1])
names2 = sorted([schema["function"]["name"] for schema in schemas2])
assert names1 == names2
assert len(schemas1) == len(schemas2) == 2
# Test order independence
reversed_tools = list(reversed(selected_tools))
schemas3 = manager.get_all_tools_schema(reversed_tools)
names3 = sorted([schema["function"]["name"] for schema in schemas3])
assert names1 == names3
def test_selected_tools_duplicate_handling(self):
"""Test how selected_tools handles duplicates."""
@tool
def duplicate_test_tool(agent, x: int) -> int:
"""Duplicate test tool.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager()
# Test with duplicate tool names
selected_tools = ["duplicate_test_tool", "duplicate_test_tool"]
schemas = manager.get_all_tools_schema(selected_tools)
# Should return schemas for each request (may include duplicates)
assert len(schemas) == 2
def test_multiple_managers_selected_tools(self):
"""Test selected_tools functionality with multiple ToolManager instances."""
@tool
def shared_tool_1(agent, x: int) -> int:
"""Shared tool 1.
Args:
agent: The agent making the request (provided automatically)
x: Input.
Returns:
Output.
"""
return x
@tool
def shared_tool_2(agent, y: str) -> str:
"""Shared tool 2.
Args:
agent: The agent making the request (provided automatically)
y: Input.
Returns:
Output.
"""
return y
manager1 = ToolManager()
manager2 = ToolManager()
# Both managers should have the same tools
all_schemas_1 = manager1.get_all_tools_schema()
all_schemas_2 = manager2.get_all_tools_schema()
assert len(all_schemas_1) == len(all_schemas_2)
# Selected tools should work the same on both managers
selected_tools = ["shared_tool_1"]
schemas_1 = manager1.get_all_tools_schema(selected_tools)
schemas_2 = manager2.get_all_tools_schema(selected_tools)
assert len(schemas_1) == len(schemas_2) == 1
assert schemas_1[0]["function"]["name"] == schemas_2[0]["function"]["name"]
@pytest.mark.asyncio
async def test_acall_tools_success(self, monkeypatch):
"""
This test validates the full async tool execution pipeline by ensuring that:
- Tool calls are correctly extracted from the LLM response object.
- JSON-formatted arguments are parsed without error.
- The `agent` parameter is automatically injected when required by the tool's function signature.
- The tool function is executed successfully.
- The result is wrapped in the expected structured response format:
{
"tool_call_id": <tool_call_id>,
"role": "tool",
"name": <tool_name>,
"response": <stringified_result>
}
- The asynchronous execution path using `asyncio.gather`
returns the correct results.
"""
manager = ToolManager()
mock_agent = Mock()
@tool
def async_test_tool(agent, value: str) -> str:
"""Async test tool.
Args:
agent: The agent making the request (provided automatically)
value: Input value.
Returns:
Processed value.
"""
return f"Async: {value}"
mock_tool_call = Mock()
mock_tool_call.id = "call_async"
mock_tool_call.function.name = "async_test_tool"
mock_tool_call.function.arguments = '{"value": "hello"}'
mock_response = Mock()
mock_response.tool_calls = [mock_tool_call]
result = await manager.acall_tools(mock_agent, mock_response)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_async"
assert "Async: hello" in result[0]["response"]
# ------------------------------------------------------------------
# Tests for include_builtins and remove_tool (issue #90)
# ------------------------------------------------------------------
def test_include_builtins_false_starts_empty(self):
"""ToolManager(include_builtins=False) should have no global tools."""
# Register some global tools first
@tool
def global_tool_a(agent, x: int) -> int:
"""Global tool A.
Args:
agent: Provided automatically.
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager(include_builtins=False)
assert len(manager.tools) == 0
assert not manager.has_tool("global_tool_a")
def test_include_builtins_true_includes_global_tools(self):
"""Default ToolManager should include globally registered tools."""
@tool
def global_tool_b(agent, x: int) -> int:
"""Global tool B.
Args:
agent: Provided automatically.
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager(include_builtins=True)
assert manager.has_tool("global_tool_b")
def test_include_builtins_false_with_extra_tools(self):
"""include_builtins=False with extra_tools should only have extras."""
@tool
def global_should_not_appear(agent, x: int) -> int:
"""Should not appear.
Args:
agent: Provided automatically.
x: Input.
Returns:
Output.
"""
return x
def custom_tool(agent, y: str) -> str:
return y
custom_tool.__tool_schema__ = {
"type": "function",
"function": {"name": "custom_tool"},
}
manager = ToolManager(
include_builtins=False,
extra_tools={"custom_tool": custom_tool},
)
assert manager.has_tool("custom_tool")
assert not manager.has_tool("global_should_not_appear")
assert len(manager.tools) == 1
def test_remove_tool(self):
"""remove_tool should remove a tool by name."""
@tool
def removable_tool(agent, x: int) -> int:
"""Removable tool.
Args:
agent: Provided automatically.
x: Input.
Returns:
Output.
"""
return x
manager = ToolManager()
assert manager.has_tool("removable_tool")
manager.remove_tool("removable_tool")
assert not manager.has_tool("removable_tool")
def test_remove_tool_missing_is_silent(self):
"""remove_tool should not raise on missing tool names."""
manager = ToolManager()
manager.remove_tool("nonexistent_tool") # should not raise