Skip to content

Commit 5df105d

Browse files
committed
test: add tests for can_use_tool callback features (issue #159)
- Add test for PermissionResultAllow with updatedPermissions support - Add test for PermissionResultDeny with interrupt flag (deny and stop) - Add test for PermissionResultDeny without interrupt (deny and continue) - Export PermissionRuleValue type for creating permission updates These tests verify the can_use_tool callback correctly handles: 1. The behavior/updatedInput field names (not allow/input) 2. The updatedPermissions field for "Always Allow" functionality 3. The interrupt flag for stopping vs continuing after denial Fixes #159
1 parent 27575ae commit 5df105d

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

src/claude_agent_sdk/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
PermissionResult,
3535
PermissionResultAllow,
3636
PermissionResultDeny,
37+
PermissionRuleValue,
3738
PermissionUpdate,
3839
PostToolUseHookInput,
3940
PreCompactHookInput,
@@ -327,6 +328,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any:
327328
"PermissionResult",
328329
"PermissionResultAllow",
329330
"PermissionResultDeny",
331+
"PermissionRuleValue",
330332
"PermissionUpdate",
331333
# Hook support
332334
"HookCallback",

tests/test_tool_callbacks.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
HookMatcher,
1313
PermissionResultAllow,
1414
PermissionResultDeny,
15+
PermissionRuleValue,
16+
PermissionUpdate,
1517
ToolPermissionContext,
1618
)
1719
from claude_agent_sdk._internal.query import Query
@@ -207,6 +209,156 @@ async def error_callback(
207209
assert '"subtype": "error"' in response
208210
assert "Callback error" in response
209211

212+
@pytest.mark.asyncio
213+
async def test_permission_callback_with_updated_permissions(self):
214+
"""Test callback that returns allow with updated permissions (Always Allow)."""
215+
216+
async def allow_with_permissions_callback(
217+
tool_name: str, input_data: dict, context: ToolPermissionContext
218+
) -> PermissionResultAllow:
219+
# Return allow with permission updates for "Always Allow" functionality
220+
return PermissionResultAllow(
221+
updated_permissions=[
222+
PermissionUpdate(
223+
type="addRules",
224+
behavior="allow",
225+
rules=[
226+
PermissionRuleValue(tool_name="Bash", rule_content=None)
227+
],
228+
destination="session",
229+
)
230+
]
231+
)
232+
233+
transport = MockTransport()
234+
query = Query(
235+
transport=transport,
236+
is_streaming_mode=True,
237+
can_use_tool=allow_with_permissions_callback,
238+
hooks=None,
239+
)
240+
241+
request = {
242+
"type": "control_request",
243+
"request_id": "test-4",
244+
"request": {
245+
"subtype": "can_use_tool",
246+
"tool_name": "Bash",
247+
"input": {"command": "ls -la"},
248+
"permission_suggestions": [],
249+
},
250+
}
251+
252+
await query._handle_control_request(request)
253+
254+
# Check response includes updatedPermissions
255+
assert len(transport.written_messages) == 1
256+
response = transport.written_messages[0]
257+
response_data = json.loads(response)
258+
259+
# Get the nested response data
260+
result = response_data["response"]["response"]
261+
262+
assert result.get("behavior") == "allow"
263+
assert "updatedPermissions" in result
264+
assert len(result["updatedPermissions"]) == 1
265+
assert result["updatedPermissions"][0]["type"] == "addRules"
266+
assert result["updatedPermissions"][0]["behavior"] == "allow"
267+
assert result["updatedPermissions"][0]["destination"] == "session"
268+
269+
@pytest.mark.asyncio
270+
async def test_permission_callback_deny_with_interrupt(self):
271+
"""Test callback that denies with interrupt flag to stop execution."""
272+
273+
async def deny_with_interrupt_callback(
274+
tool_name: str, input_data: dict, context: ToolPermissionContext
275+
) -> PermissionResultDeny:
276+
# Deny and interrupt - stop the agent completely
277+
return PermissionResultDeny(
278+
message="Critical security violation - stopping agent",
279+
interrupt=True,
280+
)
281+
282+
transport = MockTransport()
283+
query = Query(
284+
transport=transport,
285+
is_streaming_mode=True,
286+
can_use_tool=deny_with_interrupt_callback,
287+
hooks=None,
288+
)
289+
290+
request = {
291+
"type": "control_request",
292+
"request_id": "test-5-interrupt",
293+
"request": {
294+
"subtype": "can_use_tool",
295+
"tool_name": "DangerousTool",
296+
"input": {"command": "rm -rf /"},
297+
"permission_suggestions": [],
298+
},
299+
}
300+
301+
await query._handle_control_request(request)
302+
303+
# Check response includes interrupt flag
304+
assert len(transport.written_messages) == 1
305+
response = transport.written_messages[0]
306+
response_data = json.loads(response)
307+
308+
# Get the nested response data
309+
result = response_data["response"]["response"]
310+
311+
assert result.get("behavior") == "deny"
312+
assert result.get("message") == "Critical security violation - stopping agent"
313+
assert result.get("interrupt") is True
314+
315+
@pytest.mark.asyncio
316+
async def test_permission_callback_deny_without_interrupt(self):
317+
"""Test callback that denies without interrupt (deny and continue)."""
318+
319+
async def deny_without_interrupt_callback(
320+
tool_name: str, input_data: dict, context: ToolPermissionContext
321+
) -> PermissionResultDeny:
322+
# Deny but don't interrupt - let the agent try a different approach
323+
return PermissionResultDeny(
324+
message="Tool not allowed, try a different approach",
325+
interrupt=False,
326+
)
327+
328+
transport = MockTransport()
329+
query = Query(
330+
transport=transport,
331+
is_streaming_mode=True,
332+
can_use_tool=deny_without_interrupt_callback,
333+
hooks=None,
334+
)
335+
336+
request = {
337+
"type": "control_request",
338+
"request_id": "test-6-no-interrupt",
339+
"request": {
340+
"subtype": "can_use_tool",
341+
"tool_name": "SomeTool",
342+
"input": {},
343+
"permission_suggestions": [],
344+
},
345+
}
346+
347+
await query._handle_control_request(request)
348+
349+
# Check response does NOT include interrupt flag when False
350+
assert len(transport.written_messages) == 1
351+
response = transport.written_messages[0]
352+
response_data = json.loads(response)
353+
354+
# Get the nested response data
355+
result = response_data["response"]["response"]
356+
357+
assert result.get("behavior") == "deny"
358+
assert result.get("message") == "Tool not allowed, try a different approach"
359+
# interrupt should not be present when False
360+
assert "interrupt" not in result
361+
210362

211363
class TestHookCallbacks:
212364
"""Test hook callback functionality."""

0 commit comments

Comments
 (0)