|
17 | 17 |
|
18 | 18 | import pytest
|
19 | 19 | from aiohttp import ClientSession
|
| 20 | +from toolbox_core.client import ToolboxClient as ToolboxCoreClient |
20 | 21 | from toolbox_core.protocol import ManifestSchema
|
| 22 | +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema |
| 23 | +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool |
| 24 | +from toolbox_core.utils import params_to_pydantic_model |
21 | 25 |
|
22 | 26 | from toolbox_langchain.async_client import AsyncToolboxClient
|
23 | 27 | from toolbox_langchain.async_tools import AsyncToolboxTool
|
@@ -60,123 +64,200 @@ def manifest_schema(self):
|
60 | 64 | def mock_session(self):
|
61 | 65 | return AsyncMock(spec=ClientSession)
|
62 | 66 |
|
| 67 | + @pytest.fixture |
| 68 | + def mock_core_client_instance(self, manifest_schema, mock_session): |
| 69 | + mock = AsyncMock(spec=ToolboxCoreClient) |
| 70 | + |
| 71 | + async def mock_load_tool_impl(name, auth_token_getters, bound_params): |
| 72 | + tool_schema_dict = MANIFEST_JSON["tools"].get(name) |
| 73 | + if not tool_schema_dict: |
| 74 | + raise ValueError(f"Tool '{name}' not in mock manifest_dict") |
| 75 | + |
| 76 | + core_params = [ |
| 77 | + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] |
| 78 | + ] |
| 79 | + # Return a mock that looks like toolbox_core.tool.ToolboxTool |
| 80 | + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) |
| 81 | + core_tool_mock.__name__ = name |
| 82 | + core_tool_mock.__doc__ = tool_schema_dict["description"] |
| 83 | + core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params) |
| 84 | + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them |
| 85 | + return core_tool_mock |
| 86 | + |
| 87 | + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) |
| 88 | + |
| 89 | + async def mock_load_toolset_impl( |
| 90 | + name, auth_token_getters, bound_params, strict |
| 91 | + ): |
| 92 | + core_tools_list = [] |
| 93 | + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): |
| 94 | + core_params = [ |
| 95 | + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] |
| 96 | + ] |
| 97 | + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) |
| 98 | + core_tool_mock.__name__ = tool_name_iter |
| 99 | + core_tool_mock.__doc__ = tool_schema_dict["description"] |
| 100 | + core_tool_mock._pydantic_model = params_to_pydantic_model( |
| 101 | + tool_name_iter, core_params |
| 102 | + ) |
| 103 | + core_tools_list.append(core_tool_mock) |
| 104 | + return core_tools_list |
| 105 | + |
| 106 | + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) |
| 107 | + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests |
| 108 | + mock._ToolboxClient__session = mock_session |
| 109 | + return mock |
| 110 | + |
63 | 111 | @pytest.fixture()
|
64 |
| - def mock_client(self, mock_session): |
65 |
| - return AsyncToolboxClient(URL, session=mock_session) |
| 112 | + def mock_client(self, mock_session, mock_core_client_instance): |
| 113 | + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient |
| 114 | + with patch( |
| 115 | + "toolbox_langchain.async_client.ToolboxCoreClient", |
| 116 | + return_value=mock_core_client_instance, |
| 117 | + ): |
| 118 | + client = AsyncToolboxClient(URL, session=mock_session) |
| 119 | + # Ensure the mocked core client is used |
| 120 | + client._AsyncToolboxClient__core_client = mock_core_client_instance |
| 121 | + return client |
66 | 122 |
|
67 | 123 | async def test_create_with_existing_session(self, mock_client, mock_session):
|
68 |
| - assert mock_client._AsyncToolboxClient__session == mock_session |
| 124 | + # AsyncToolboxClient stores the core_client, which stores the session |
| 125 | + assert ( |
| 126 | + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session |
| 127 | + == mock_session |
| 128 | + ) |
69 | 129 |
|
70 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
71 | 130 | async def test_aload_tool(
|
72 |
| - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 131 | + self, |
| 132 | + mock_client, |
| 133 | + manifest_schema, # mock_session removed as it's part of mock_core_client_instance |
73 | 134 | ):
|
74 | 135 | tool_name = "test_tool_1"
|
75 |
| - mock_load_manifest.return_value = manifest_schema |
| 136 | + # manifest_schema is used by mock_core_client_instance fixture to provide tool details |
76 | 137 |
|
77 | 138 | tool = await mock_client.aload_tool(tool_name)
|
78 | 139 |
|
79 |
| - mock_load_manifest.assert_called_once_with( |
80 |
| - f"{URL}/api/tool/{tool_name}", mock_session |
| 140 | + # Assert that the core client's load_tool was called correctly |
| 141 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 142 | + name=tool_name, auth_token_getters={}, bound_params={} |
81 | 143 | )
|
82 | 144 | assert isinstance(tool, AsyncToolboxTool)
|
83 |
| - assert tool.name == tool_name |
| 145 | + assert ( |
| 146 | + tool.name == tool_name |
| 147 | + ) # AsyncToolboxTool gets its name from the core_tool |
84 | 148 |
|
85 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
86 | 149 | async def test_aload_tool_auth_headers_deprecated(
|
87 |
| - self, mock_load_manifest, mock_client, manifest_schema |
| 150 | + self, mock_client, manifest_schema |
88 | 151 | ):
|
89 | 152 | tool_name = "test_tool_1"
|
90 |
| - mock_manifest = manifest_schema |
91 |
| - mock_load_manifest.return_value = mock_manifest |
| 153 | + auth_lambda = lambda: "Bearer token" # Define lambda once |
92 | 154 | with catch_warnings(record=True) as w:
|
93 | 155 | simplefilter("always")
|
94 | 156 | await mock_client.aload_tool(
|
95 |
| - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} |
| 157 | + tool_name, |
| 158 | + auth_headers={"Authorization": auth_lambda}, # Use the defined lambda |
96 | 159 | )
|
97 | 160 | assert len(w) == 1
|
98 | 161 | assert issubclass(w[-1].category, DeprecationWarning)
|
99 | 162 | assert "auth_headers" in str(w[-1].message)
|
100 | 163 |
|
101 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
| 164 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 165 | + name=tool_name, |
| 166 | + auth_token_getters={"Authorization": auth_lambda}, |
| 167 | + bound_params={}, |
| 168 | + ) |
| 169 | + |
102 | 170 | async def test_aload_tool_auth_headers_and_tokens(
|
103 |
| - self, mock_load_manifest, mock_client, manifest_schema |
| 171 | + self, mock_client, manifest_schema |
104 | 172 | ):
|
105 | 173 | tool_name = "test_tool_1"
|
106 |
| - mock_manifest = manifest_schema |
107 |
| - mock_load_manifest.return_value = mock_manifest |
| 174 | + auth_getters = {"test": lambda: "token"} |
| 175 | + auth_headers_lambda = lambda: "Bearer token" # Define lambda once |
| 176 | + |
108 | 177 | with catch_warnings(record=True) as w:
|
109 | 178 | simplefilter("always")
|
110 | 179 | await mock_client.aload_tool(
|
111 | 180 | tool_name,
|
112 |
| - auth_headers={"Authorization": lambda: "Bearer token"}, |
113 |
| - auth_token_getters={"test": lambda: "token"}, |
| 181 | + auth_headers={ |
| 182 | + "Authorization": auth_headers_lambda |
| 183 | + }, # Use defined lambda |
| 184 | + auth_token_getters=auth_getters, |
114 | 185 | )
|
115 |
| - assert len(w) == 1 |
| 186 | + assert ( |
| 187 | + len(w) == 1 |
| 188 | + ) # Only one warning because auth_token_getters takes precedence |
116 | 189 | assert issubclass(w[-1].category, DeprecationWarning)
|
117 |
| - assert "auth_headers" in str(w[-1].message) |
| 190 | + assert "auth_headers" in str(w[-1].message) # Warning for auth_headers |
| 191 | + |
| 192 | + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( |
| 193 | + name=tool_name, auth_token_getters=auth_getters, bound_params={} |
| 194 | + ) |
118 | 195 |
|
119 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
120 | 196 | async def test_aload_toolset(
|
121 |
| - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 197 | + self, mock_client, manifest_schema # mock_session removed |
122 | 198 | ):
|
123 |
| - mock_manifest = manifest_schema |
124 |
| - mock_load_manifest.return_value = mock_manifest |
125 | 199 | tools = await mock_client.aload_toolset()
|
126 | 200 |
|
127 |
| - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) |
128 |
| - assert len(tools) == 2 |
| 201 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 202 | + name=None, auth_token_getters={}, bound_params={}, strict=False |
| 203 | + ) |
| 204 | + assert len(tools) == 2 # Based on MANIFEST_JSON |
129 | 205 | for tool in tools:
|
130 | 206 | assert isinstance(tool, AsyncToolboxTool)
|
131 | 207 | assert tool.name in ["test_tool_1", "test_tool_2"]
|
132 | 208 |
|
133 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
134 | 209 | async def test_aload_toolset_with_toolset_name(
|
135 |
| - self, mock_load_manifest, mock_client, mock_session, manifest_schema |
| 210 | + self, mock_client, manifest_schema # mock_session removed |
136 | 211 | ):
|
137 |
| - toolset_name = "test_toolset_1" |
138 |
| - mock_manifest = manifest_schema |
139 |
| - mock_load_manifest.return_value = mock_manifest |
| 212 | + toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it |
140 | 213 | tools = await mock_client.aload_toolset(toolset_name=toolset_name)
|
141 | 214 |
|
142 |
| - mock_load_manifest.assert_called_once_with( |
143 |
| - f"{URL}/api/toolset/{toolset_name}", mock_session |
| 215 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 216 | + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False |
144 | 217 | )
|
145 | 218 | assert len(tools) == 2
|
146 | 219 | for tool in tools:
|
147 | 220 | assert isinstance(tool, AsyncToolboxTool)
|
148 | 221 | assert tool.name in ["test_tool_1", "test_tool_2"]
|
149 | 222 |
|
150 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
151 | 223 | async def test_aload_toolset_auth_headers_deprecated(
|
152 |
| - self, mock_load_manifest, mock_client, manifest_schema |
| 224 | + self, mock_client, manifest_schema |
153 | 225 | ):
|
154 |
| - mock_manifest = manifest_schema |
155 |
| - mock_load_manifest.return_value = mock_manifest |
| 226 | + auth_lambda = lambda: "Bearer token" # Define lambda once |
156 | 227 | with catch_warnings(record=True) as w:
|
157 | 228 | simplefilter("always")
|
158 | 229 | await mock_client.aload_toolset(
|
159 |
| - auth_headers={"Authorization": lambda: "Bearer token"} |
| 230 | + auth_headers={"Authorization": auth_lambda} # Use defined lambda |
160 | 231 | )
|
161 | 232 | assert len(w) == 1
|
162 | 233 | assert issubclass(w[-1].category, DeprecationWarning)
|
163 | 234 | assert "auth_headers" in str(w[-1].message)
|
| 235 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 236 | + name=None, |
| 237 | + auth_token_getters={"Authorization": auth_lambda}, |
| 238 | + bound_params={}, |
| 239 | + strict=False, |
| 240 | + ) |
164 | 241 |
|
165 |
| - @patch("toolbox_langchain.async_client._load_manifest") |
166 | 242 | async def test_aload_toolset_auth_headers_and_tokens(
|
167 |
| - self, mock_load_manifest, mock_client, manifest_schema |
| 243 | + self, mock_client, manifest_schema |
168 | 244 | ):
|
169 |
| - mock_manifest = manifest_schema |
170 |
| - mock_load_manifest.return_value = mock_manifest |
| 245 | + auth_getters = {"test": lambda: "token"} |
| 246 | + auth_headers_lambda = lambda: "Bearer token" # Define lambda once |
171 | 247 | with catch_warnings(record=True) as w:
|
172 | 248 | simplefilter("always")
|
173 | 249 | await mock_client.aload_toolset(
|
174 |
| - auth_headers={"Authorization": lambda: "Bearer token"}, |
175 |
| - auth_token_getters={"test": lambda: "token"}, |
| 250 | + auth_headers={ |
| 251 | + "Authorization": auth_headers_lambda |
| 252 | + }, # Use defined lambda |
| 253 | + auth_token_getters=auth_getters, |
176 | 254 | )
|
177 | 255 | assert len(w) == 1
|
178 | 256 | assert issubclass(w[-1].category, DeprecationWarning)
|
179 | 257 | assert "auth_headers" in str(w[-1].message)
|
| 258 | + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( |
| 259 | + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False |
| 260 | + ) |
180 | 261 |
|
181 | 262 | async def test_load_tool_not_implemented(self, mock_client):
|
182 | 263 | with pytest.raises(NotImplementedError) as excinfo:
|
|
0 commit comments