11"""Tests for tools module."""
22
33import pytest
4- from unittest .mock import AsyncMock , patch , PropertyMock
4+ from unittest .mock import AsyncMock , patch , PropertyMock , Mock
55import httpx
66from pia_mcp_server .tools .pia_search import handle_pia_search
7+ from pia_mcp_server .tools .pia_search_facets import handle_pia_search_facets
78from pia_mcp_server .config import Settings
89
910settings = Settings ()
1213@pytest .mark .asyncio
1314async def test_pia_search_no_api_key ():
1415 """Test PIA search without API key."""
15- with patch (
16- "pia_mcp_server.tools.pia_search.settings.API_KEY" ,
17- side_effect = ValueError ("PIA API key is required" ),
18- ):
16+ with patch .object (Settings , "_get_api_key_from_args" , return_value = None ):
1917 result = await handle_pia_search ({"query" : "test" })
2018
2119 assert len (result ) == 1
@@ -36,16 +34,15 @@ async def test_pia_search_success():
3634 },
3735 }
3836
39- with patch .object (
40- type (settings ),
41- "API_KEY" ,
42- new_callable = lambda : PropertyMock (return_value = "test_key" ),
43- ):
44- with patch ("httpx.AsyncClient.post" ) as mock_post :
45- mock_response_obj = AsyncMock ()
37+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "test_key" ):
38+ with patch ("httpx.AsyncClient" ) as mock_client :
39+ mock_response_obj = Mock ()
4640 mock_response_obj .json .return_value = mock_response
47- mock_response_obj .raise_for_status = AsyncMock ()
48- mock_post .return_value = mock_response_obj
41+ mock_response_obj .raise_for_status .return_value = None
42+
43+ mock_client_instance = AsyncMock ()
44+ mock_client_instance .post .return_value = mock_response_obj
45+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
4946
5047 result = await handle_pia_search ({"query" : "test fraud" })
5148
@@ -62,16 +59,15 @@ async def test_pia_search_api_error():
6259 "error" : {"code" : - 32000 , "message" : "Invalid API key" },
6360 }
6461
65- with patch .object (
66- type (settings ),
67- "API_KEY" ,
68- new_callable = lambda : PropertyMock (return_value = "invalid_key" ),
69- ):
70- with patch ("httpx.AsyncClient.post" ) as mock_post :
71- mock_response_obj = AsyncMock ()
62+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "invalid_key" ):
63+ with patch ("httpx.AsyncClient" ) as mock_client :
64+ mock_response_obj = Mock ()
7265 mock_response_obj .json .return_value = mock_response
73- mock_response_obj .raise_for_status = AsyncMock ()
74- mock_post .return_value = mock_response_obj
66+ mock_response_obj .raise_for_status .return_value = None
67+
68+ mock_client_instance = AsyncMock ()
69+ mock_client_instance .post .return_value = mock_response_obj
70+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
7571
7672 result = await handle_pia_search ({"query" : "test" })
7773
@@ -82,22 +78,112 @@ async def test_pia_search_api_error():
8278@pytest .mark .asyncio
8379async def test_pia_search_http_error ():
8480 """Test PIA search with HTTP error."""
85- with patch .object (
86- type (settings ),
87- "API_KEY" ,
88- new_callable = lambda : PropertyMock (return_value = "test_key" ),
89- ):
90- with patch ("httpx.AsyncClient.post" ) as mock_post :
81+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "test_key" ):
82+ with patch ("httpx.AsyncClient" ) as mock_client :
9183 mock_response_obj = AsyncMock ()
9284 mock_response_obj .status_code = 500
9385 mock_response_obj .text = "Internal Server Error"
9486
9587 http_error = httpx .HTTPStatusError (
9688 "500 Server Error" , request = AsyncMock (), response = mock_response_obj
9789 )
98- mock_post .side_effect = http_error
90+
91+ mock_client_instance = AsyncMock ()
92+ mock_client_instance .post .side_effect = http_error
93+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
9994
10095 result = await handle_pia_search ({"query" : "test" })
10196
10297 assert len (result ) == 1
10398 assert "HTTP Error 500" in result [0 ].text
99+
100+
101+ @pytest .mark .asyncio
102+ async def test_pia_search_facets_no_api_key ():
103+ """Test PIA search facets without API key."""
104+ with patch .object (Settings , "_get_api_key_from_args" , return_value = None ):
105+ result = await handle_pia_search_facets ({"query" : "test" })
106+
107+ assert len (result ) == 1
108+ assert "PIA API key is required" in result [0 ].text
109+
110+
111+ @pytest .mark .asyncio
112+ async def test_pia_search_facets_success ():
113+ """Test successful PIA search facets."""
114+ mock_response = {
115+ "jsonrpc" : "2.0" ,
116+ "id" : 1 ,
117+ "result" : {
118+ "facets" : {
119+ "data_source" : ["OIG" , "GAO" , "CMS" ],
120+ "document_type" : ["audit_report" , "investigation" , "guidance" ],
121+ "agency" : ["HHS" , "DOD" , "VA" ],
122+ }
123+ },
124+ }
125+
126+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "test_key" ):
127+ with patch ("httpx.AsyncClient" ) as mock_client :
128+ mock_response_obj = Mock ()
129+ mock_response_obj .json .return_value = mock_response
130+ mock_response_obj .raise_for_status .return_value = None
131+
132+ mock_client_instance = AsyncMock ()
133+ mock_client_instance .post .return_value = mock_response_obj
134+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
135+
136+ result = await handle_pia_search_facets ({"query" : "healthcare" })
137+
138+ assert len (result ) == 1
139+ assert "data_source" in result [0 ].text
140+ assert "OIG" in result [0 ].text
141+ assert "document_type" in result [0 ].text
142+
143+
144+ @pytest .mark .asyncio
145+ async def test_pia_search_facets_api_error ():
146+ """Test PIA search facets with API error."""
147+ mock_response = {
148+ "jsonrpc" : "2.0" ,
149+ "id" : 1 ,
150+ "error" : {"code" : - 32000 , "message" : "Invalid query format" },
151+ }
152+
153+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "invalid_key" ):
154+ with patch ("httpx.AsyncClient" ) as mock_client :
155+ mock_response_obj = Mock ()
156+ mock_response_obj .json .return_value = mock_response
157+ mock_response_obj .raise_for_status .return_value = None
158+
159+ mock_client_instance = AsyncMock ()
160+ mock_client_instance .post .return_value = mock_response_obj
161+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
162+
163+ result = await handle_pia_search_facets ({"query" : "test" })
164+
165+ assert len (result ) == 1
166+ assert "API Error: Invalid query format" in result [0 ].text
167+
168+
169+ @pytest .mark .asyncio
170+ async def test_pia_search_facets_http_error ():
171+ """Test PIA search facets with HTTP error."""
172+ with patch .object (Settings , "_get_api_key_from_args" , return_value = "test_key" ):
173+ with patch ("httpx.AsyncClient" ) as mock_client :
174+ mock_response_obj = AsyncMock ()
175+ mock_response_obj .status_code = 403
176+ mock_response_obj .text = "Forbidden"
177+
178+ http_error = httpx .HTTPStatusError (
179+ "403 Client Error" , request = AsyncMock (), response = mock_response_obj
180+ )
181+
182+ mock_client_instance = AsyncMock ()
183+ mock_client_instance .post .side_effect = http_error
184+ mock_client .return_value .__aenter__ .return_value = mock_client_instance
185+
186+ result = await handle_pia_search_facets ({"query" : "test" })
187+
188+ assert len (result ) == 1
189+ assert "HTTP Error 403" in result [0 ].text
0 commit comments