22
33from mellea .backends import Backend
44from mellea .backends .ollama import OllamaModelBackend
5- from mellea .backends .tools import add_tools_from_model_options
5+ from mellea .backends .tools import add_tools_from_context_actions , add_tools_from_model_options
66from mellea .backends .types import ModelOption
7- from mellea .stdlib .base import ModelOutputThunk
7+ from mellea .stdlib .base import CBlock , Component , ModelOutputThunk , TemplateRepresentation
88from mellea .stdlib .docs .richdocument import Table
99from mellea .stdlib .session import LinearContext , MelleaSession
1010
@@ -29,13 +29,47 @@ def table() -> Table:
2929 assert t is not None , "test setup failed: could not create table from markdown"
3030 return t
3131
32- def test_add_tools_from_model_options (table : Table ):
32+ class FakeToolComponent (Component ):
33+ def __init__ (self ) -> None :
34+ super ().__init__ ()
35+
36+ def tool1 (self ):
37+ return
38+
39+ def parts (self ):
40+ return []
41+
42+ def format_for_llm (self ) -> TemplateRepresentation :
43+ return TemplateRepresentation (
44+ obj = self ,
45+ args = {"arg" : None },
46+ tools = {
47+ self .tool1 .__name__ : self .tool1
48+ }
49+ )
50+
51+ class FakeToolComponentWithExtraTool (FakeToolComponent ):
52+ def __init__ (self ) -> None :
53+ super ().__init__ ()
54+
55+ def tool2 (self ):
56+ return
57+
58+ def format_for_llm (self ) -> TemplateRepresentation :
59+ tr = super ().format_for_llm ()
60+ assert tr .tools is not None
61+ tr .tools [self .tool2 .__name__ ] = self .tool2
62+ return tr
63+
64+
65+ def test_add_tools_from_model_options_list ():
3366 def get_weather (location : str ) -> int :
3467 """Returns the weather in Celsius."""
3568 return 21
3669
70+ ftc = FakeToolComponent ()
3771 model_options = {
38- ModelOption .TOOLS : [get_weather , table . content_as_string ]
72+ ModelOption .TOOLS : [get_weather , ftc . tool1 ]
3973 }
4074
4175 tools = {}
@@ -44,36 +78,75 @@ def get_weather(location: str) -> int:
4478 assert tools ["get_weather" ] == get_weather
4579
4680 # Must use `==` for bound methods.
47- assert tools ["content_as_string " ] == table . content_as_string , f"{ tools ["content_as_string " ]} is not { table . content_as_string } "
81+ assert tools ["tool1 " ] == ftc . tool1 , f"{ tools ["tool1 " ]} should == { ftc . tool1 } "
4882
49- def test_tool_called (m : MelleaSession , table : Table ):
50- """We don't force tools to be called. As a result, this test might unexpectedly fail."""
51- r = 10
5283
53- returned_tool = False
54- for i in range (r ):
55- transformed = m .transform (table , "add a new row to this table" )
56- if isinstance (transformed , Table ):
57- returned_tool = True
58- break
84+ def test_add_tools_from_model_options_map ():
85+ def get_weather (location : str ) -> int :
86+ """Returns the weather in Celsius."""
87+ return 21
88+
89+ ftc = FakeToolComponent ()
90+ model_options = {
91+ ModelOption .TOOLS : {
92+ get_weather .__name__ : get_weather ,
93+ ftc .tool1 .__name__ : ftc .tool1
94+ }
95+ }
96+
97+ tools = {}
98+ add_tools_from_model_options (tools , model_options )
99+
100+ assert tools ["get_weather" ] == get_weather
101+
102+ # Must use `==` for bound methods.
103+ assert tools ["tool1" ] == ftc .tool1 , f"{ tools ["tool1" ]} should == { ftc .tool1 } "
104+
105+
106+ def test_add_tools_from_context_actions ():
107+
108+ ftc1 = FakeToolComponentWithExtraTool ()
109+ ftc2 = FakeToolComponent ()
110+
111+ ctx_actions = [CBlock ("Hello" ), ftc1 , ftc2 ]
112+ tools = {}
113+ add_tools_from_context_actions (tools , ctx_actions )
114+
115+ # Check that tools with the same name get properly overwritten in order of ctx.
116+ assert tools ["tool1" ] == ftc2 .tool1 , f"{ tools ["tool1" ]} should == { ftc2 .tool1 } "
117+
118+ # Check that tools that aren't overwritten are still there.
119+ assert tools ["tool2" ] == ftc1 .tool2 , f"{ tools ["tool2" ]} should == { ftc1 .tool2 } "
120+
121+
122+ # def test_tool_called(m: MelleaSession, table: Table):
123+ # """We don't force tools to be called. As a result, this test might unexpectedly fail."""
124+ # r = 10
125+
126+ # returned_tool = False
127+ # for i in range(r):
128+ # transformed = m.transform(table, "add a new row to this table")
129+ # if isinstance(transformed, Table):
130+ # returned_tool = True
131+ # break
59132
60- assert returned_tool , f"did not return a tool after { r } attempts"
133+ # assert returned_tool, f"did not return a tool after {r} attempts"
61134
62135
63- def test_tool_not_called (m : MelleaSession , table : Table ):
64- """Ensure tools aren't always called when provided."""
65- r = 10
136+ # def test_tool_not_called(m: MelleaSession, table: Table):
137+ # """Ensure tools aren't always called when provided."""
138+ # r = 10
66139
67- returned_no_tool = False
68- for i in range (r ):
69- transformed = m .transform (table , "output a text description of this table" )
70- if isinstance (transformed , ModelOutputThunk ):
71- returned_no_tool = True
72- break
140+ # returned_no_tool = False
141+ # for i in range(r):
142+ # transformed = m.transform(table, "output a text description of this table")
143+ # if isinstance(transformed, ModelOutputThunk):
144+ # returned_no_tool = True
145+ # break
73146
74- assert (
75- returned_no_tool
76- ), f"only returned tools after { r } attempts, should've returned a response with no tools"
147+ # assert (
148+ # returned_no_tool
149+ # ), f"only returned tools after {r} attempts, should've returned a response with no tools"
77150
78151if __name__ == "__main__" :
79152 pytest .main ([__file__ ])
0 commit comments