|
12 | 12 | HookMatcher, |
13 | 13 | PermissionResultAllow, |
14 | 14 | PermissionResultDeny, |
| 15 | + PermissionRuleValue, |
| 16 | + PermissionUpdate, |
15 | 17 | ToolPermissionContext, |
16 | 18 | ) |
17 | 19 | from claude_agent_sdk._internal.query import Query |
@@ -207,6 +209,156 @@ async def error_callback( |
207 | 209 | assert '"subtype": "error"' in response |
208 | 210 | assert "Callback error" in response |
209 | 211 |
|
| 212 | + @pytest.mark.asyncio |
| 213 | + async def test_permission_callback_with_updated_permissions(self): |
| 214 | + """Test callback that returns allow with updated permissions (Always Allow).""" |
| 215 | + |
| 216 | + async def allow_with_permissions_callback( |
| 217 | + tool_name: str, input_data: dict, context: ToolPermissionContext |
| 218 | + ) -> PermissionResultAllow: |
| 219 | + # Return allow with permission updates for "Always Allow" functionality |
| 220 | + return PermissionResultAllow( |
| 221 | + updated_permissions=[ |
| 222 | + PermissionUpdate( |
| 223 | + type="addRules", |
| 224 | + behavior="allow", |
| 225 | + rules=[ |
| 226 | + PermissionRuleValue(tool_name="Bash", rule_content=None) |
| 227 | + ], |
| 228 | + destination="session", |
| 229 | + ) |
| 230 | + ] |
| 231 | + ) |
| 232 | + |
| 233 | + transport = MockTransport() |
| 234 | + query = Query( |
| 235 | + transport=transport, |
| 236 | + is_streaming_mode=True, |
| 237 | + can_use_tool=allow_with_permissions_callback, |
| 238 | + hooks=None, |
| 239 | + ) |
| 240 | + |
| 241 | + request = { |
| 242 | + "type": "control_request", |
| 243 | + "request_id": "test-4", |
| 244 | + "request": { |
| 245 | + "subtype": "can_use_tool", |
| 246 | + "tool_name": "Bash", |
| 247 | + "input": {"command": "ls -la"}, |
| 248 | + "permission_suggestions": [], |
| 249 | + }, |
| 250 | + } |
| 251 | + |
| 252 | + await query._handle_control_request(request) |
| 253 | + |
| 254 | + # Check response includes updatedPermissions |
| 255 | + assert len(transport.written_messages) == 1 |
| 256 | + response = transport.written_messages[0] |
| 257 | + response_data = json.loads(response) |
| 258 | + |
| 259 | + # Get the nested response data |
| 260 | + result = response_data["response"]["response"] |
| 261 | + |
| 262 | + assert result.get("behavior") == "allow" |
| 263 | + assert "updatedPermissions" in result |
| 264 | + assert len(result["updatedPermissions"]) == 1 |
| 265 | + assert result["updatedPermissions"][0]["type"] == "addRules" |
| 266 | + assert result["updatedPermissions"][0]["behavior"] == "allow" |
| 267 | + assert result["updatedPermissions"][0]["destination"] == "session" |
| 268 | + |
| 269 | + @pytest.mark.asyncio |
| 270 | + async def test_permission_callback_deny_with_interrupt(self): |
| 271 | + """Test callback that denies with interrupt flag to stop execution.""" |
| 272 | + |
| 273 | + async def deny_with_interrupt_callback( |
| 274 | + tool_name: str, input_data: dict, context: ToolPermissionContext |
| 275 | + ) -> PermissionResultDeny: |
| 276 | + # Deny and interrupt - stop the agent completely |
| 277 | + return PermissionResultDeny( |
| 278 | + message="Critical security violation - stopping agent", |
| 279 | + interrupt=True, |
| 280 | + ) |
| 281 | + |
| 282 | + transport = MockTransport() |
| 283 | + query = Query( |
| 284 | + transport=transport, |
| 285 | + is_streaming_mode=True, |
| 286 | + can_use_tool=deny_with_interrupt_callback, |
| 287 | + hooks=None, |
| 288 | + ) |
| 289 | + |
| 290 | + request = { |
| 291 | + "type": "control_request", |
| 292 | + "request_id": "test-5-interrupt", |
| 293 | + "request": { |
| 294 | + "subtype": "can_use_tool", |
| 295 | + "tool_name": "DangerousTool", |
| 296 | + "input": {"command": "rm -rf /"}, |
| 297 | + "permission_suggestions": [], |
| 298 | + }, |
| 299 | + } |
| 300 | + |
| 301 | + await query._handle_control_request(request) |
| 302 | + |
| 303 | + # Check response includes interrupt flag |
| 304 | + assert len(transport.written_messages) == 1 |
| 305 | + response = transport.written_messages[0] |
| 306 | + response_data = json.loads(response) |
| 307 | + |
| 308 | + # Get the nested response data |
| 309 | + result = response_data["response"]["response"] |
| 310 | + |
| 311 | + assert result.get("behavior") == "deny" |
| 312 | + assert result.get("message") == "Critical security violation - stopping agent" |
| 313 | + assert result.get("interrupt") is True |
| 314 | + |
| 315 | + @pytest.mark.asyncio |
| 316 | + async def test_permission_callback_deny_without_interrupt(self): |
| 317 | + """Test callback that denies without interrupt (deny and continue).""" |
| 318 | + |
| 319 | + async def deny_without_interrupt_callback( |
| 320 | + tool_name: str, input_data: dict, context: ToolPermissionContext |
| 321 | + ) -> PermissionResultDeny: |
| 322 | + # Deny but don't interrupt - let the agent try a different approach |
| 323 | + return PermissionResultDeny( |
| 324 | + message="Tool not allowed, try a different approach", |
| 325 | + interrupt=False, |
| 326 | + ) |
| 327 | + |
| 328 | + transport = MockTransport() |
| 329 | + query = Query( |
| 330 | + transport=transport, |
| 331 | + is_streaming_mode=True, |
| 332 | + can_use_tool=deny_without_interrupt_callback, |
| 333 | + hooks=None, |
| 334 | + ) |
| 335 | + |
| 336 | + request = { |
| 337 | + "type": "control_request", |
| 338 | + "request_id": "test-6-no-interrupt", |
| 339 | + "request": { |
| 340 | + "subtype": "can_use_tool", |
| 341 | + "tool_name": "SomeTool", |
| 342 | + "input": {}, |
| 343 | + "permission_suggestions": [], |
| 344 | + }, |
| 345 | + } |
| 346 | + |
| 347 | + await query._handle_control_request(request) |
| 348 | + |
| 349 | + # Check response does NOT include interrupt flag when False |
| 350 | + assert len(transport.written_messages) == 1 |
| 351 | + response = transport.written_messages[0] |
| 352 | + response_data = json.loads(response) |
| 353 | + |
| 354 | + # Get the nested response data |
| 355 | + result = response_data["response"]["response"] |
| 356 | + |
| 357 | + assert result.get("behavior") == "deny" |
| 358 | + assert result.get("message") == "Tool not allowed, try a different approach" |
| 359 | + # interrupt should not be present when False |
| 360 | + assert "interrupt" not in result |
| 361 | + |
210 | 362 |
|
211 | 363 | class TestHookCallbacks: |
212 | 364 | """Test hook callback functionality.""" |
|
0 commit comments