@@ -24,15 +24,24 @@ async def test_gateway_client_init():
2424 client = GatewayClient (
2525 base_url = "http://localhost:8888/" ,
2626 default_model = "test-model" ,
27- polling_interval = 0.01 ,
27+ polling_interval = 0.5 ,
2828 timeout = 0.1 ,
2929 )
3030 assert client .base_url == "http://localhost:8888"
3131 assert client .default_model == "test-model"
32- assert client .polling_interval == 0.01
32+ assert client .polling_interval == 0.5
3333 assert client .timeout == 0.1
3434 assert client ._client is None
3535
36+ client = GatewayClient (
37+ base_url = "http://localhost:8888/" ,
38+ polling_interval = 10 ,
39+ )
40+ assert client .polling_interval == 3.0 # Maximum 3.0 seconds
41+
42+ client .polling_interval = 0.05
43+ assert client .polling_interval == 0.5 # Minimum is 0.5 seconds
44+
3645
3746@pytest .mark .asyncio
3847async def test_gateway_client_aenter_aexit (mock_httpx_async_client ):
@@ -65,16 +74,18 @@ async def test_submit_task_success(mock_httpx_async_client):
6574 _ , mock_client_instance = mock_httpx_async_client
6675 mock_response = MagicMock ()
6776 mock_response .json .return_value = {"uuid" : "task-123" , "status" : "pending" }
68- mock_client_instance .post .return_value = mock_response
77+ mock_response .raise_for_status .return_value = mock_response
78+ mock_client_instance .request .return_value = mock_response
6979
7080 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
7181 task_info = await client .submit_task (
7282 model_name = "my_model" , task = "process" , data = "some text"
7383 )
7484
7585 assert task_info == {"uuid" : "task-123" , "status" : "pending" }
76- mock_client_instance .post .assert_awaited_once_with (
77- "http://test-gateway.com/models/my_model/tasks/process" ,
86+ mock_client_instance .request .assert_awaited_once_with (
87+ method = "POST" ,
88+ url = "http://test-gateway.com/models/my_model/tasks/process" ,
7889 data = "some text" ,
7990 json = None ,
8091 files = None ,
@@ -90,16 +101,18 @@ async def test_submit_task_with_default_model(mock_httpx_async_client):
90101 _ , mock_client_instance = mock_httpx_async_client
91102 mock_response = MagicMock ()
92103 mock_response .json .return_value = {"uuid" : "task-456" , "status" : "pending" }
93- mock_client_instance .post .return_value = mock_response
104+ mock_response .raise_for_status .return_value = mock_response
105+ mock_client_instance .request .return_value = mock_response
94106
95107 async with GatewayClient (
96108 base_url = "http://test-gateway.com" , default_model = "default_model"
97109 ) as client :
98110 task_info = await client .submit_task (task = "process" , data = "some text" )
99111
100112 assert task_info == {"uuid" : "task-456" , "status" : "pending" }
101- mock_client_instance .post .assert_awaited_once_with (
102- "http://test-gateway.com/models/default_model/tasks/process" ,
113+ mock_client_instance .request .assert_awaited_once_with (
114+ method = "POST" ,
115+ url = "http://test-gateway.com/models/default_model/tasks/process" ,
103116 data = "some text" ,
104117 json = None ,
105118 files = None ,
@@ -124,7 +137,7 @@ async def test_submit_task_http_error(mock_httpx_async_client):
124137 mock_response .raise_for_status .side_effect = httpx .HTTPStatusError (
125138 "Bad Request" , request = httpx .Request ("POST" , "url" ), response = httpx .Response (400 )
126139 )
127- mock_client_instance .post .return_value = mock_response
140+ mock_client_instance .request .return_value = mock_response
128141
129142 with pytest .raises (httpx .HTTPStatusError ):
130143 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
@@ -137,7 +150,8 @@ async def test_submit_task_wait_for_completion_and_return_result(mock_httpx_asyn
137150 _ , mock_client_instance = mock_httpx_async_client
138151 mock_response = MagicMock ()
139152 mock_response .json .return_value = {"uuid" : "task-123" , "status" : "pending" }
140- mock_client_instance .post .return_value = mock_response
153+ mock_response .raise_for_status .return_value = mock_response
154+ mock_client_instance .request .return_value = mock_response
141155
142156 mock_wait_for_task = mocker .patch (
143157 "client.cogstack_model_gateway_client.client.GatewayClient.wait_for_task" ,
@@ -220,14 +234,20 @@ async def test_get_task_success(mock_httpx_async_client):
220234 mock_response = MagicMock ()
221235 mock_response .raise_for_status .return_value = mock_response
222236 mock_response .json .return_value = {"uuid" : "task-123" , "status" : "succeeded" }
223- mock_client_instance .get .return_value = mock_response
237+ mock_client_instance .request .return_value = mock_response
224238
225239 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
226240 task_info = await client .get_task ("task-123" )
227241
228242 assert task_info == {"uuid" : "task-123" , "status" : "succeeded" }
229- mock_client_instance .get .assert_awaited_once_with (
230- "http://test-gateway.com/tasks/task-123" , params = {"detail" : True , "download" : False }
243+ mock_client_instance .request .assert_awaited_once_with (
244+ method = "GET" ,
245+ url = "http://test-gateway.com/tasks/task-123" ,
246+ params = {"detail" : True , "download" : False },
247+ data = None ,
248+ json = None ,
249+ files = None ,
250+ headers = None ,
231251 )
232252 mock_response .raise_for_status .assert_called_once ()
233253
@@ -239,7 +259,7 @@ async def test_get_task_result_json(mock_httpx_async_client):
239259 mock_response = MagicMock ()
240260 mock_response .content = b'{"key": "value"}'
241261 mock_response .raise_for_status .return_value = mock_response
242- mock_client_instance .get .return_value = mock_response
262+ mock_client_instance .request .return_value = mock_response
243263
244264 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
245265 result = await client .get_task_result ("task-123" )
@@ -253,7 +273,7 @@ async def test_get_task_result_jsonl(mock_httpx_async_client):
253273 mock_response = MagicMock ()
254274 mock_response .content = b'{"item": 1}\n {"item": 2}\n '
255275 mock_response .raise_for_status .return_value = mock_response
256- mock_client_instance .get .return_value = mock_response
276+ mock_client_instance .request .return_value = mock_response
257277
258278 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
259279 result = await client .get_task_result ("task-123" )
@@ -267,7 +287,7 @@ async def test_get_task_result_text(mock_httpx_async_client):
267287 mock_response = MagicMock ()
268288 mock_response .content = b"plain text result"
269289 mock_response .raise_for_status .return_value = mock_response
270- mock_client_instance .get .return_value = mock_response
290+ mock_client_instance .request .return_value = mock_response
271291
272292 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
273293 result = await client .get_task_result ("task-123" )
@@ -281,7 +301,7 @@ async def test_get_task_result_binary(mock_httpx_async_client):
281301 mock_response = MagicMock ()
282302 mock_response .content = b"\x80 \x01 \x02 \x03 " # Example binary data
283303 mock_response .raise_for_status .return_value = mock_response
284- mock_client_instance .get .return_value = mock_response
304+ mock_client_instance .request .return_value = mock_response
285305
286306 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
287307 result = await client .get_task_result ("task-123" )
@@ -295,7 +315,7 @@ async def test_get_task_result_no_parse(mock_httpx_async_client):
295315 mock_response = MagicMock ()
296316 mock_response .content = b'{"key": "value"}'
297317 mock_response .raise_for_status .return_value = mock_response
298- mock_client_instance .get .return_value = mock_response
318+ mock_client_instance .request .return_value = mock_response
299319
300320 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
301321 result = await client .get_task_result ("task-123" , parse = False )
@@ -334,7 +354,7 @@ async def test_wait_for_task_timeout(mock_httpx_async_client, mocker):
334354
335355 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
336356 client .timeout = 0.05
337- client .polling_interval = 0.01
357+ client .polling_interval = 0.5
338358
339359 with pytest .raises (
340360 TimeoutError , match = "Timed out waiting for task 'task-polling' to complete"
@@ -394,13 +414,20 @@ async def test_get_models_success(mock_httpx_async_client):
394414 _ , mock_client_instance = mock_httpx_async_client
395415 mock_response = MagicMock ()
396416 mock_response .json .return_value = ["model_a" , "model_b" ]
397- mock_client_instance .get .return_value = mock_response
417+ mock_response .raise_for_status .return_value = mock_response
418+ mock_client_instance .request .return_value = mock_response
398419
399420 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
400421 models = await client .get_models ()
401422 assert models == ["model_a" , "model_b" ]
402- mock_client_instance .get .assert_awaited_once_with (
403- "http://test-gateway.com/models/" , params = {"verbose" : False }
423+ mock_client_instance .request .assert_awaited_once_with (
424+ method = "GET" ,
425+ url = "http://test-gateway.com/models/" ,
426+ params = {"verbose" : False },
427+ data = None ,
428+ json = None ,
429+ files = None ,
430+ headers = None ,
404431 )
405432
406433
@@ -410,13 +437,20 @@ async def test_get_model_success(mock_httpx_async_client):
410437 _ , mock_client_instance = mock_httpx_async_client
411438 mock_response = MagicMock ()
412439 mock_response .json .return_value = {"name" : "my_model" , "status" : "deployed" }
413- mock_client_instance .get .return_value = mock_response
440+ mock_response .raise_for_status .return_value = mock_response
441+ mock_client_instance .request .return_value = mock_response
414442
415443 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
416444 model_info = await client .get_model (model_name = "my_model" )
417445 assert model_info == {"name" : "my_model" , "status" : "deployed" }
418- mock_client_instance .get .assert_awaited_once_with (
419- "http://test-gateway.com/models/my_model/info"
446+ mock_client_instance .request .assert_awaited_once_with (
447+ method = "GET" ,
448+ url = "http://test-gateway.com/models/my_model/info" ,
449+ params = None ,
450+ data = None ,
451+ json = None ,
452+ files = None ,
453+ headers = None ,
420454 )
421455
422456
@@ -426,15 +460,22 @@ async def test_get_model_with_default_model(mock_httpx_async_client):
426460 _ , mock_client_instance = mock_httpx_async_client
427461 mock_response = MagicMock ()
428462 mock_response .json .return_value = {"name" : "default_model" , "status" : "deployed" }
429- mock_client_instance .get .return_value = mock_response
463+ mock_response .raise_for_status .return_value = mock_response
464+ mock_client_instance .request .return_value = mock_response
430465
431466 async with GatewayClient (
432467 base_url = "http://test-gateway.com" , default_model = "default_model"
433468 ) as client :
434469 model_info = await client .get_model ()
435470 assert model_info == {"name" : "default_model" , "status" : "deployed" }
436- mock_client_instance .get .assert_awaited_once_with (
437- "http://test-gateway.com/models/default_model/info"
471+ mock_client_instance .request .assert_awaited_once_with (
472+ method = "GET" ,
473+ url = "http://test-gateway.com/models/default_model/info" ,
474+ params = None ,
475+ data = None ,
476+ json = None ,
477+ files = None ,
478+ headers = None ,
438479 )
439480
440481
@@ -452,7 +493,8 @@ async def test_deploy_model_success(mock_httpx_async_client):
452493 _ , mock_client_instance = mock_httpx_async_client
453494 mock_response = MagicMock ()
454495 mock_response .json .return_value = {"name" : "new_model" , "status" : "deploying" }
455- mock_client_instance .post .return_value = mock_response
496+ mock_response .raise_for_status .return_value = mock_response
497+ mock_client_instance .request .return_value = mock_response
456498
457499 async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
458500 deploy_info = await client .deploy_model (
@@ -462,11 +504,16 @@ async def test_deploy_model_success(mock_httpx_async_client):
462504 )
463505
464506 assert deploy_info == {"name" : "new_model" , "status" : "deploying" }
465- mock_client_instance .post .assert_awaited_once_with (
466- "http://test-gateway.com/models/new_model" ,
507+ mock_client_instance .request .assert_awaited_once_with (
508+ method = "POST" ,
509+ url = "http://test-gateway.com/models/new_model" ,
467510 json = {
468511 "tracking_id" : None ,
469512 "model_uri" : "mlflow-artifacts:/1/runidabcd1234/artifacts/new_model" ,
470513 "ttl" : 3600 ,
471514 },
515+ params = None ,
516+ data = None ,
517+ files = None ,
518+ headers = None ,
472519 )
0 commit comments