17
17
18
18
# Standard
19
19
from datetime import datetime , timezone
20
- from unittest .mock import AsyncMock , MagicMock , Mock
20
+ from unittest .mock import AsyncMock , MagicMock , Mock , patch
21
21
22
22
# Third-Party
23
23
import pytest
@@ -146,7 +146,7 @@ class TestGatewayService:
146
146
# ────────────────────────────────────────────────────────────────────
147
147
148
148
@pytest .mark .asyncio
149
- async def test_register_gateway (self , gateway_service , test_db ):
149
+ async def test_register_gateway (self , gateway_service , test_db , monkeypatch ):
150
150
"""Successful gateway registration populates DB and returns data."""
151
151
# DB: no gateway with that name; no existing tools found
152
152
test_db .execute = Mock (
@@ -172,6 +172,18 @@ async def test_register_gateway(self, gateway_service, test_db):
172
172
)
173
173
gateway_service ._notify_gateway_added = AsyncMock ()
174
174
175
+ # Patch GatewayRead.model_validate to return a mock with .masked()
176
+ mock_model = Mock ()
177
+ mock_model .masked .return_value = mock_model
178
+ mock_model .name = "test_gateway"
179
+ mock_model .url = "http://example.com/gateway"
180
+ mock_model .description = "A test gateway"
181
+
182
+ monkeypatch .setattr (
183
+ "mcpgateway.services.gateway_service.GatewayRead.model_validate" ,
184
+ lambda x : mock_model ,
185
+ )
186
+
175
187
gateway_create = GatewayCreate (
176
188
name = "test_gateway" ,
177
189
url = "http://example.com/gateway" ,
@@ -236,10 +248,18 @@ async def test_register_gateway_connection_error(self, gateway_service, test_db)
236
248
# ────────────────────────────────────────────────────────────────────
237
249
238
250
@pytest .mark .asyncio
239
- async def test_list_gateways (self , gateway_service , mock_gateway , test_db ):
251
+ async def test_list_gateways (self , gateway_service , mock_gateway , test_db , monkeypatch ):
240
252
"""Listing gateways returns the active ones."""
253
+
241
254
test_db .execute = Mock (return_value = _make_execute_result (scalars_list = [mock_gateway ]))
242
255
256
+ mock_model = Mock ()
257
+ mock_model .masked .return_value = mock_model
258
+ mock_model .name = "test_gateway"
259
+
260
+ # Patch using full path string to GatewayRead.model_validate
261
+ monkeypatch .setattr ("mcpgateway.services.gateway_service.GatewayRead.model_validate" , lambda x : mock_model )
262
+
243
263
result = await gateway_service .list_gateways (test_db )
244
264
245
265
test_db .execute .assert_called_once ()
@@ -249,6 +269,7 @@ async def test_list_gateways(self, gateway_service, mock_gateway, test_db):
249
269
@pytest .mark .asyncio
250
270
async def test_get_gateway (self , gateway_service , mock_gateway , test_db ):
251
271
"""Gateway is fetched and returned by ID."""
272
+ mock_gateway .masked = Mock (return_value = mock_gateway )
252
273
test_db .get = Mock (return_value = mock_gateway )
253
274
result = await gateway_service .get_gateway (test_db , 1 )
254
275
test_db .get .assert_called_once_with (DbGateway , 1 )
@@ -266,14 +287,24 @@ async def test_get_gateway_not_found(self, gateway_service, test_db):
266
287
async def test_get_gateway_inactive (self , gateway_service , mock_gateway , test_db ):
267
288
"""Inactive gateway is not returned unless explicitly asked for."""
268
289
mock_gateway .enabled = False
290
+ mock_gateway .id = 1
269
291
test_db .get = Mock (return_value = mock_gateway )
270
- result = await gateway_service .get_gateway (test_db , 1 , include_inactive = True )
271
- assert result .id == 1
272
- assert result .enabled == False
273
- test_db .get .reset_mock ()
274
- test_db .get = Mock (return_value = mock_gateway )
275
- with pytest .raises (GatewayNotFoundError ):
276
- result = await gateway_service .get_gateway (test_db , 1 , include_inactive = False )
292
+
293
+ # Create a mock for GatewayRead with a masked method
294
+ mock_gateway_read = Mock ()
295
+ mock_gateway_read .id = 1
296
+ mock_gateway_read .enabled = False
297
+ mock_gateway_read .masked = Mock (return_value = mock_gateway_read )
298
+
299
+ with patch ("mcpgateway.services.gateway_service.GatewayRead.model_validate" , return_value = mock_gateway_read ):
300
+ result = await gateway_service .get_gateway (test_db , 1 , include_inactive = True )
301
+ assert result .id == 1
302
+ assert result .enabled == False
303
+
304
+ # Now test the inactive = False path
305
+ test_db .get = Mock (return_value = mock_gateway )
306
+ with pytest .raises (GatewayNotFoundError ):
307
+ await gateway_service .get_gateway (test_db , 1 , include_inactive = False )
277
308
278
309
# ────────────────────────────────────────────────────────────────────
279
310
# UPDATE
@@ -288,22 +319,36 @@ async def test_update_gateway(self, gateway_service, mock_gateway, test_db):
288
319
test_db .commit = Mock ()
289
320
test_db .refresh = Mock ()
290
321
322
+ # Simulate successful gateway initialization
291
323
gateway_service ._initialize_gateway = AsyncMock (
292
324
return_value = (
293
- {"prompts" : {"subscribe" : True }, "resources" : {"subscribe" : True }, "tools" : {"subscribe" : True }},
325
+ {
326
+ "prompts" : {"subscribe" : True },
327
+ "resources" : {"subscribe" : True },
328
+ "tools" : {"subscribe" : True },
329
+ },
294
330
[],
295
331
)
296
332
)
297
333
gateway_service ._notify_gateway_updated = AsyncMock ()
298
334
335
+ # Create the update payload
299
336
gateway_update = GatewayUpdate (
300
337
name = "updated_gateway" ,
301
338
url = "http://example.com/updated" ,
302
339
description = "Updated description" ,
303
340
)
304
341
305
- result = await gateway_service .update_gateway (test_db , 1 , gateway_update )
342
+ # Create mock return for GatewayRead.model_validate().masked()
343
+ mock_gateway_read = MagicMock ()
344
+ mock_gateway_read .name = "updated_gateway"
345
+ mock_gateway_read .masked .return_value = mock_gateway_read # Ensure .masked() returns the same object
346
+
347
+ # Patch the model_validate call in the service
348
+ with patch ("mcpgateway.services.gateway_service.GatewayRead.model_validate" , return_value = mock_gateway_read ):
349
+ result = await gateway_service .update_gateway (test_db , 1 , gateway_update )
306
350
351
+ # Assertions
307
352
test_db .commit .assert_called_once ()
308
353
test_db .refresh .assert_called_once ()
309
354
gateway_service ._initialize_gateway .assert_called_once ()
@@ -354,6 +399,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
354
399
query_proxy .filter .return_value = filter_proxy
355
400
test_db .query = Mock (return_value = query_proxy )
356
401
402
+ # Setup gateway service mocks
357
403
gateway_service ._notify_gateway_activated = AsyncMock ()
358
404
gateway_service ._notify_gateway_deactivated = AsyncMock ()
359
405
gateway_service ._initialize_gateway = AsyncMock (return_value = ({"prompts" : {}}, []))
@@ -362,12 +408,17 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
362
408
tool_service_stub .toggle_tool_status = AsyncMock ()
363
409
gateway_service .tool_service = tool_service_stub
364
410
365
- result = await gateway_service .toggle_gateway_status (test_db , 1 , activate = False )
411
+ # Patch model_validate to return a mock with .masked()
412
+ mock_gateway_read = MagicMock ()
413
+ mock_gateway_read .masked .return_value = mock_gateway_read
414
+
415
+ with patch ("mcpgateway.services.gateway_service.GatewayRead.model_validate" , return_value = mock_gateway_read ):
416
+ result = await gateway_service .toggle_gateway_status (test_db , 1 , activate = False )
366
417
367
418
assert mock_gateway .enabled is False
368
419
gateway_service ._notify_gateway_deactivated .assert_called_once ()
369
420
assert tool_service_stub .toggle_tool_status .called
370
- assert result . enabled is False
421
+ assert result == mock_gateway_read
371
422
372
423
# ────────────────────────────────────────────────────────────────────
373
424
# DELETE
0 commit comments