1+ import json
2+
13import pytest
4+ import requests
25from fastapi .testclient import TestClient
3- from testcontainers .minio import MinioContainer
4- from testcontainers .postgres import PostgresContainer
5- from testcontainers .rabbitmq import RabbitMqContainer
66
7+ from cogstack_model_gateway .common .config import Config , load_config
8+ from cogstack_model_gateway .common .object_store import ObjectStoreManager
9+ from cogstack_model_gateway .common .queue import QueueManager
10+ from cogstack_model_gateway .common .tasks import Status , TaskManager
711from cogstack_model_gateway .gateway .main import app
812from tests .integration .utils import (
9- clone_cogstack_model_serve ,
13+ TEST_MODEL_SERVICE ,
1014 configure_environment ,
11- remove_cogstack_model_serve ,
12- remove_testcontainers ,
13- start_cogstack_model_serve ,
14- start_scheduler ,
15- start_testcontainers ,
16- stop_cogstack_model_serve ,
17- stop_scheduler ,
15+ setup_cms ,
16+ setup_scheduler ,
17+ setup_testcontainers ,
1818)
1919
20- POSTGRES_IMAGE = "postgres:17.2"
21- RABBITMQ_IMAGE = "rabbitmq:4.0.4-management-alpine"
22- MINIO_IMAGE = "minio/minio:RELEASE.2024-11-07T00-52-20Z"
23-
2420
2521@pytest .fixture (scope = "module" , autouse = True )
26- def setup (request ):
27- postgres = PostgresContainer (POSTGRES_IMAGE )
28- rabbitmq = RabbitMqContainer (RABBITMQ_IMAGE )
29- minio = MinioContainer (MINIO_IMAGE )
30-
31- containers = [postgres , rabbitmq , minio ]
32- request .addfinalizer (lambda : remove_testcontainers (containers ))
22+ def setup (request : pytest .FixtureRequest , cleanup_cms : bool ):
23+ postgres , rabbitmq , minio = setup_testcontainers (request )
3324
34- start_testcontainers (containers )
25+ svc_addr_map = setup_cms (request , cleanup_cms )
26+ request .config .cache .set ("TEST_MODEL_SERVICE_IP" , svc_addr_map [TEST_MODEL_SERVICE ]["address" ])
3527
36- configure_environment (postgres , rabbitmq , minio )
28+ mlflow_addr = svc_addr_map ["mlflow-ui" ]["address" ]
29+ mlflow_port = svc_addr_map ["mlflow-ui" ]["port" ]
30+ env = {
31+ "MLFLOW_TRACKING_URI" : f"http://{ mlflow_addr } :{ mlflow_port } " ,
32+ }
33+ configure_environment (postgres , rabbitmq , minio , extras = env )
3734
38- scheduler_process = start_scheduler ()
39- request .addfinalizer (lambda : stop_scheduler (scheduler_process ))
40-
41- clone_cogstack_model_serve ()
42- request .addfinalizer (remove_cogstack_model_serve )
43-
44- cms_compose_envs = start_cogstack_model_serve ()
45- request .addfinalizer (lambda : stop_cogstack_model_serve (cms_compose_envs ))
35+ setup_scheduler (request )
4636
4737
4838@pytest .fixture (scope = "module" )
@@ -51,7 +41,165 @@ def client():
5141 yield client
5242
5343
44+ @pytest .fixture (scope = "module" )
45+ def config (client : TestClient ) -> Config :
46+ return load_config ()
47+
48+
49+ @pytest .fixture (scope = "module" )
50+ def test_model_service_ip (request : pytest .FixtureRequest ) -> str :
51+ return request .config .cache .get ("TEST_MODEL_SERVICE_IP" , None )
52+
53+
54+ def test_config_loaded (config : Config ):
55+ assert config
56+ assert all (
57+ key in config
58+ for key in [
59+ "database_manager" ,
60+ "task_object_store_manager" ,
61+ "results_object_store_manager" ,
62+ "queue_manager" ,
63+ "task_manager" ,
64+ ]
65+ )
66+
67+
5468def test_root (client : TestClient ):
5569 response = client .get ("/" )
5670 assert response .status_code == 200
5771 assert response .json () == {"message" : "Enter the cult... I mean, the API." }
72+
73+
74+ def test_get_tasks (client : TestClient ):
75+ response = client .get ("/tasks/" )
76+ assert response .status_code == 403
77+ assert response .json () == {"detail" : "Only admins can list tasks" }
78+
79+
80+ def test_get_task_by_uuid (client : TestClient , config : Config ):
81+ task_uuid = "nonexistent-uuid"
82+ response = client .get (f"/tasks/{ task_uuid } " )
83+ assert response .status_code == 404
84+ assert response .json () == {"detail" : f"Task '{ task_uuid } ' not found" }
85+
86+ tm : TaskManager = config .task_manager
87+ task_uuid = tm .create_task (status = "pending" )
88+ response = client .get (f"/tasks/{ task_uuid } " )
89+ assert response .status_code == 200
90+ assert response .json () == {"uuid" : task_uuid , "status" : "pending" }
91+
92+ tm .update_task (task_uuid , status = "succeeded" , result = "result.txt" , error_message = None )
93+ response = client .get (f"/tasks/{ task_uuid } " , params = {"detail" : True })
94+ assert response .status_code == 200
95+ assert response .json () == {
96+ "uuid" : task_uuid ,
97+ "status" : "succeeded" ,
98+ "result" : "result.txt" ,
99+ "error_message" : None ,
100+ "tracking_id" : None ,
101+ }
102+
103+
104+ def test_get_models (client : TestClient ):
105+ response = client .get ("/models/" )
106+ assert response .status_code == 200
107+
108+ response_json = response .json ()
109+ assert isinstance (response_json , list )
110+ assert len (response_json ) == 1
111+ assert all (key in response_json [0 ] for key in ["name" , "uri" ])
112+ assert response_json [0 ]["name" ] == TEST_MODEL_SERVICE
113+
114+
115+ def test_get_model_info (client : TestClient , test_model_service_ip : str ):
116+ response = client .get (f"/models/{ test_model_service_ip } /info" )
117+ assert response .status_code == 200
118+ assert all (
119+ key in response .json ()
120+ for key in ["api_version" , "model_type" , "model_description" , "model_card" ]
121+ )
122+
123+
124+ def test_unsupported_task (client : TestClient , test_model_service_ip : str ):
125+ response = client .post (
126+ f"/models/{ test_model_service_ip } /unsupported-task" ,
127+ headers = {"Content-Type" : "dummy" },
128+ )
129+ assert response .status_code == 404
130+ assert "Task 'unsupported-task' not found. Supported tasks are:" in response .json ()["detail" ]
131+
132+
133+ def test_process (client : TestClient , config : Config , test_model_service_ip : str ):
134+ response = client .post (
135+ f"/models/{ test_model_service_ip } /process" ,
136+ data = "Spinal stenosis" ,
137+ headers = {"Content-Type" : "text/plain" },
138+ )
139+ assert response .status_code == 200
140+ response_json = response .json ()
141+ assert all (key in response_json for key in ["uuid" , "status" ])
142+
143+ task_uuid = response_json ["uuid" ]
144+ tm : TaskManager = config .task_manager
145+ assert tm .get_task (task_uuid ), "Failed to submit task: not found in the database"
146+
147+ # Wait for the task to complete
148+ while (task := tm .get_task (task_uuid )).status != Status .SUCCEEDED :
149+ pass
150+
151+ # Verify that the task payload was stored in the object store
152+ task_payload_key = f"{ task_uuid } _payload.txt"
153+ tom : ObjectStoreManager = config .task_object_store_manager
154+ payload = tom .get_object (task_payload_key )
155+ assert payload == b"Spinal stenosis"
156+
157+ # Verify that the queue is empty after the task is processed
158+ qm : QueueManager = config .queue_manager
159+ assert qm .is_queue_empty ()
160+
161+ # Verify task results
162+ assert task .error_message is None , f"Task failed unexpectedly: { task .error_message } "
163+ assert task .result is not None , "Task results are missing"
164+
165+ rom : ObjectStoreManager = config .results_object_store_manager
166+ result = rom .get_object (task .result )
167+
168+ try :
169+ result_json = json .loads (result .decode ("utf-8" ))
170+ except json .JSONDecodeError as e :
171+ pytest .fail (f"Failed to parse the result as JSON: { result } , { e } " )
172+
173+ assert result_json ["text" ] == "Spinal stenosis"
174+ assert len (result_json ["annotations" ]) == 1
175+
176+ annotation = result_json ["annotations" ][0 ]
177+ assert all (
178+ key in annotation
179+ for key in [
180+ "start" ,
181+ "end" ,
182+ "label_name" ,
183+ "label_id" ,
184+ "categories" ,
185+ "accuracy" ,
186+ "meta_anns" ,
187+ "athena_ids" ,
188+ ]
189+ )
190+ assert annotation ["label_name" ] == "Spinal Stenosis"
191+
192+ # Verify that the above match the information exposed through the user-facing API
193+ get_response = client .get (f"/tasks/{ task_uuid } " , params = {"detail" : True , "download_url" : True })
194+ assert get_response .status_code == 200
195+
196+ get_response_json = get_response .json ()
197+ assert get_response_json ["uuid" ] == task .uuid
198+ assert get_response_json ["status" ] == task .status
199+ assert get_response_json ["error_message" ] is None
200+ assert get_response_json ["tracking_id" ] is None
201+
202+ # Download results and verify they match the ones read from the object store
203+ download_results = requests .get (get_response_json ["result" ])
204+ assert download_results .status_code == 200
205+ assert download_results .content == result
0 commit comments