44from unittest .mock import AsyncMock , Mock
55
66
7+ def create_initialize_request (client_name : str ) -> mt .InitializeRequest :
8+ """Create a real InitializeRequest object."""
9+ return mt .InitializeRequest (
10+ method = 'initialize' ,
11+ params = mt .InitializeRequestParams (
12+ protocolVersion = '2024-11-05' ,
13+ capabilities = mt .ClientCapabilities (),
14+ clientInfo = mt .Implementation (name = client_name , version = '1.0' ),
15+ ),
16+ )
17+
18+
719@pytest .mark .asyncio
820async def test_on_initialize_connects_client ():
921 """Test that on_initialize calls client._connect()."""
@@ -17,7 +29,7 @@ async def test_on_initialize_connects_client():
1729 middleware = InitializeMiddleware (mock_factory )
1830
1931 mock_context = Mock ()
20- mock_context .message = Mock ( spec = mt . InitializeRequest )
32+ mock_context .message = create_initialize_request ( 'test-client' )
2133
2234 mock_call_next = AsyncMock ()
2335
@@ -42,11 +54,45 @@ async def test_on_initialize_fails_if_connect_fails():
4254 middleware = InitializeMiddleware (mock_factory )
4355
4456 mock_context = Mock ()
45- mock_context .message = Mock ( spec = mt . InitializeRequest )
57+ mock_context .message = create_initialize_request ( 'test-client' )
4658
4759 mock_call_next = AsyncMock ()
4860
4961 with pytest .raises (Exception , match = 'Connection failed' ):
5062 await middleware .on_initialize (mock_context , mock_call_next )
5163
5264 mock_call_next .assert_not_called ()
65+
66+
67+ @pytest .mark .asyncio
68+ @pytest .mark .parametrize (
69+ 'client_name' ,
70+ [
71+ 'Kiro CLI' ,
72+ 'kiro cli' ,
73+ 'KIRO CLI' ,
74+ 'Amazon Q Dev CLI' ,
75+ 'amazon q dev cli' ,
76+ 'Q DEV CLI' ,
77+ ],
78+ )
79+ async def test_on_initialize_skips_connect_for_special_clients (client_name ):
80+ """Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
81+ mock_client = Mock ()
82+ mock_client ._connect = AsyncMock ()
83+
84+ mock_factory = Mock ()
85+ mock_factory .set_init_params = Mock ()
86+ mock_factory .get_client = AsyncMock (return_value = mock_client )
87+
88+ middleware = InitializeMiddleware (mock_factory )
89+
90+ mock_context = Mock ()
91+ mock_context .message = create_initialize_request (client_name )
92+
93+ mock_call_next = AsyncMock ()
94+
95+ await middleware .on_initialize (mock_context , mock_call_next )
96+
97+ mock_client ._connect .assert_not_called ()
98+ mock_call_next .assert_called_once_with (mock_context )
0 commit comments