|
17 | 17 |
|
18 | 18 | import pytest |
19 | 19 | from aiohttp import ClientSession |
| 20 | +from toolbox_core.client import ToolboxClient as ToolboxCoreClient |
20 | 21 | from toolbox_core.protocol import ManifestSchema |
| 22 | +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema |
| 23 | +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool |
| 24 | +from toolbox_core.utils import params_to_pydantic_model |
21 | 25 |
|
22 | 26 | from toolbox_langchain.async_client import AsyncToolboxClient |
23 | 27 | from toolbox_langchain.async_tools import AsyncToolboxTool |
@@ -60,123 +64,200 @@ def manifest_schema(self): |
60 | 64 | def mock_session(self): |
61 | 65 | return AsyncMock(spec=ClientSession) |
62 | 66 |
|
| 67 | + @pytest.fixture |
| 68 | + def mock_core_client_instance(self, manifest_schema, mock_session): |
| 69 | + mock = AsyncMock(spec=ToolboxCoreClient) |
| 70 | + |
| 71 | + async def mock_load_tool_impl(name, auth_token_getters, bound_params): |
| 72 | + tool_schema_dict = MANIFEST_JSON["tools"].get(name) |
| 73 | + if not tool_schema_dict: |
| 74 | + raise ValueError(f"Tool '{name}' not in mock manifest_dict") |
| 75 | + |
| 76 | + core_params = [ |
| 77 | + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] |
| 78 | + ] |
| 79 | + # Return a mock that looks like toolbox_core.tool.ToolboxTool |
| 80 | + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) |
| 81 | + core_tool_mock.__name__ = name |
| 82 | + core_tool_mock.__doc__ = tool_schema_dict["description"] |
| 83 | + core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params) |
| 84 | + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them |
| 85 | + return core_tool_mock |
| 86 | + |
| 87 | + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) |
| 88 | + |
| 89 | + async def mock_load_toolset_impl( |
| 90 | + name, auth_token_getters, bound_params, strict |
| 91 | + ): |
| 92 | + core_tools_list = [] |
| 93 | + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): |
| 94 | + core_params = [ |
| 95 | + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] |
| 96 | + ] |
| 97 | + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) |
| 98 | + core_tool_mock.__name__ = tool_name_iter |
| 99 | + core_tool_mock.__doc__ = tool_schema_dict["description"] |
| 100 | + core_tool_mock._pydantic_model = params_to_pydantic_model( |
| 101 | + tool_name_iter, core_params |
| 102 | + ) |
| 103 | + core_tools_list.append(core_tool_mock) |
| 104 | + return core_tools_list |
| 105 | + |
| 106 | + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) |
| 107 | + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests |
| 108 | + mock._ToolboxClient__session = mock_session |
| 109 | + return mock |
| 110 | + |
63 | 111 | @pytest.fixture() |
64 | | - def mock_client(self, mock_session): |
65 | | - return AsyncToolboxClient(URL, session=mock_session) |
| 112 | + def mock_client(self, mock_session, mock_core_client_instance): |
| 113 | + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient |
| 114 | + with patch( |
| 115 | + "toolbox_langchain.async_client.ToolboxCoreClient", |
| 116 | + return_value=mock_core_client_instance, |
| 117 | + ): |
| 118 | + client = AsyncToolboxClient(URL, session=mock_session) |
| 119 | + # Ensure the mocked core client is used |
| 120 | + client._AsyncToolboxClient__core_client = mock_core_client_instance |
| 121 | + return client |
66 | 122 |
|
67 | 123 | async def test_create_with_existing_session(self, mock_client, mock_session): |
68 | | - assert mock_client._AsyncToolboxClient__session == mock_session |
| 124 | + # AsyncToolboxClient stores the core_client, which stores the session |
| 125 | + assert ( |
| 126 | + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session |
| 127 | + == mock_session |
| 128 | + ) |
69 | 129 |
|
70 | | - @patch("toolbox_langchain.async_client._load_manifest") |
71 | 130 | async def test_aload_tool( |
72 | | - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 131 | + self, |
| 132 | + mock_client, |
| 133 | + manifest_schema, # mock_session removed as it's part of mock_core_client_instance |
73 | 134 | ): |
74 | 135 | tool_name = "test_tool_1" |
75 | | - mock_load_manifest.return_value = manifest_schema |
| 136 | + # manifest_schema is used by mock_core_client_instance fixture to provide tool details |
76 | 137 |
|
77 | 138 | tool = await mock_client.aload_tool(tool_name) |
78 | 139 |
|
79 | | - mock_load_manifest.assert_called_once_with( |
80 | | - f"{URL}/api/tool/{tool_name}", mock_session |
| 140 | + # Assert that the core client's load_tool was called correctly |
| 141 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 142 | + name=tool_name, auth_token_getters={}, bound_params={} |
81 | 143 | ) |
82 | 144 | assert isinstance(tool, AsyncToolboxTool) |
83 | | - assert tool.name == tool_name |
| 145 | + assert ( |
| 146 | + tool.name == tool_name |
| 147 | + ) # AsyncToolboxTool gets its name from the core_tool |
84 | 148 |
|
85 | | - @patch("toolbox_langchain.async_client._load_manifest") |
86 | 149 | async def test_aload_tool_auth_headers_deprecated( |
87 | | - self, mock_load_manifest, mock_client, manifest_schema |
| 150 | + self, mock_client, manifest_schema |
88 | 151 | ): |
89 | 152 | tool_name = "test_tool_1" |
90 | | - mock_manifest = manifest_schema |
91 | | - mock_load_manifest.return_value = mock_manifest |
| 153 | + auth_lambda = lambda: "Bearer token" # Define lambda once |
92 | 154 | with catch_warnings(record=True) as w: |
93 | 155 | simplefilter("always") |
94 | 156 | await mock_client.aload_tool( |
95 | | - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} |
| 157 | + tool_name, |
| 158 | + auth_headers={"Authorization": auth_lambda}, # Use the defined lambda |
96 | 159 | ) |
97 | 160 | assert len(w) == 1 |
98 | 161 | assert issubclass(w[-1].category, DeprecationWarning) |
99 | 162 | assert "auth_headers" in str(w[-1].message) |
100 | 163 |
|
101 | | - @patch("toolbox_langchain.async_client._load_manifest") |
| 164 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 165 | + name=tool_name, |
| 166 | + auth_token_getters={"Authorization": auth_lambda}, |
| 167 | + bound_params={}, |
| 168 | + ) |
| 169 | + |
102 | 170 | async def test_aload_tool_auth_headers_and_tokens( |
103 | | - self, mock_load_manifest, mock_client, manifest_schema |
| 171 | + self, mock_client, manifest_schema |
104 | 172 | ): |
105 | 173 | tool_name = "test_tool_1" |
106 | | - mock_manifest = manifest_schema |
107 | | - mock_load_manifest.return_value = mock_manifest |
| 174 | + auth_getters = {"test": lambda: "token"} |
| 175 | + auth_headers_lambda = lambda: "Bearer token" # Define lambda once |
| 176 | + |
108 | 177 | with catch_warnings(record=True) as w: |
109 | 178 | simplefilter("always") |
110 | 179 | await mock_client.aload_tool( |
111 | 180 | tool_name, |
112 | | - auth_headers={"Authorization": lambda: "Bearer token"}, |
113 | | - auth_token_getters={"test": lambda: "token"}, |
| 181 | + auth_headers={ |
| 182 | + "Authorization": auth_headers_lambda |
| 183 | + }, # Use defined lambda |
| 184 | + auth_token_getters=auth_getters, |
114 | 185 | ) |
115 | | - assert len(w) == 1 |
| 186 | + assert ( |
| 187 | + len(w) == 1 |
| 188 | + ) # Only one warning because auth_token_getters takes precedence |
116 | 189 | assert issubclass(w[-1].category, DeprecationWarning) |
117 | | - assert "auth_headers" in str(w[-1].message) |
| 190 | + assert "auth_headers" in str(w[-1].message) # Warning for auth_headers |
| 191 | + |
| 192 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 193 | + name=tool_name, auth_token_getters=auth_getters, bound_params={} |
| 194 | + ) |
118 | 195 |
|
119 | | - @patch("toolbox_langchain.async_client._load_manifest") |
120 | 196 | async def test_aload_toolset( |
121 | | - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 197 | + self, mock_client, manifest_schema # mock_session removed |
122 | 198 | ): |
123 | | - mock_manifest = manifest_schema |
124 | | - mock_load_manifest.return_value = mock_manifest |
125 | 199 | tools = await mock_client.aload_toolset() |
126 | 200 |
|
127 | | - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) |
128 | | - assert len(tools) == 2 |
| 201 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 202 | + name=None, auth_token_getters={}, bound_params={}, strict=False |
| 203 | + ) |
| 204 | + assert len(tools) == 2 # Based on MANIFEST_JSON |
129 | 205 | for tool in tools: |
130 | 206 | assert isinstance(tool, AsyncToolboxTool) |
131 | 207 | assert tool.name in ["test_tool_1", "test_tool_2"] |
132 | 208 |
|
133 | | - @patch("toolbox_langchain.async_client._load_manifest") |
134 | 209 | async def test_aload_toolset_with_toolset_name( |
135 | | - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 210 | + self, mock_client, manifest_schema # mock_session removed |
136 | 211 | ): |
137 | | - toolset_name = "test_toolset_1" |
138 | | - mock_manifest = manifest_schema |
139 | | - mock_load_manifest.return_value = mock_manifest |
| 212 | + toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it |
140 | 213 | tools = await mock_client.aload_toolset(toolset_name=toolset_name) |
141 | 214 |
|
142 | | - mock_load_manifest.assert_called_once_with( |
143 | | - f"{URL}/api/toolset/{toolset_name}", mock_session |
| 215 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 216 | + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False |
144 | 217 | ) |
145 | 218 | assert len(tools) == 2 |
146 | 219 | for tool in tools: |
147 | 220 | assert isinstance(tool, AsyncToolboxTool) |
148 | 221 | assert tool.name in ["test_tool_1", "test_tool_2"] |
149 | 222 |
|
150 | | - @patch("toolbox_langchain.async_client._load_manifest") |
151 | 223 | async def test_aload_toolset_auth_headers_deprecated( |
152 | | - self, mock_load_manifest, mock_client, manifest_schema |
| 224 | + self, mock_client, manifest_schema |
153 | 225 | ): |
154 | | - mock_manifest = manifest_schema |
155 | | - mock_load_manifest.return_value = mock_manifest |
| 226 | + auth_lambda = lambda: "Bearer token" # Define lambda once |
156 | 227 | with catch_warnings(record=True) as w: |
157 | 228 | simplefilter("always") |
158 | 229 | await mock_client.aload_toolset( |
159 | | - auth_headers={"Authorization": lambda: "Bearer token"} |
| 230 | + auth_headers={"Authorization": auth_lambda} # Use defined lambda |
160 | 231 | ) |
161 | 232 | assert len(w) == 1 |
162 | 233 | assert issubclass(w[-1].category, DeprecationWarning) |
163 | 234 | assert "auth_headers" in str(w[-1].message) |
| 235 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 236 | + name=None, |
| 237 | + auth_token_getters={"Authorization": auth_lambda}, |
| 238 | + bound_params={}, |
| 239 | + strict=False, |
| 240 | + ) |
164 | 241 |
|
165 | | - @patch("toolbox_langchain.async_client._load_manifest") |
166 | 242 | async def test_aload_toolset_auth_headers_and_tokens( |
167 | | - self, mock_load_manifest, mock_client, manifest_schema |
| 243 | + self, mock_client, manifest_schema |
168 | 244 | ): |
169 | | - mock_manifest = manifest_schema |
170 | | - mock_load_manifest.return_value = mock_manifest |
| 245 | + auth_getters = {"test": lambda: "token"} |
| 246 | + auth_headers_lambda = lambda: "Bearer token" # Define lambda once |
171 | 247 | with catch_warnings(record=True) as w: |
172 | 248 | simplefilter("always") |
173 | 249 | await mock_client.aload_toolset( |
174 | | - auth_headers={"Authorization": lambda: "Bearer token"}, |
175 | | - auth_token_getters={"test": lambda: "token"}, |
| 250 | + auth_headers={ |
| 251 | + "Authorization": auth_headers_lambda |
| 252 | + }, # Use defined lambda |
| 253 | + auth_token_getters=auth_getters, |
176 | 254 | ) |
177 | 255 | assert len(w) == 1 |
178 | 256 | assert issubclass(w[-1].category, DeprecationWarning) |
179 | 257 | assert "auth_headers" in str(w[-1].message) |
| 258 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 259 | + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False |
| 260 | + ) |
180 | 261 |
|
181 | 262 | async def test_load_tool_not_implemented(self, mock_client): |
182 | 263 | with pytest.raises(NotImplementedError) as excinfo: |
|
0 commit comments