Skip to content

Commit 88c1d57

Browse files
committed
WIP: Finally set up integration tests properly
Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent ac84fbb commit 88c1d57

File tree

6 files changed

+641
-319
lines changed

6 files changed

+641
-319
lines changed

poetry.lock

Lines changed: 268 additions & 251 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ select = ["E", "F", "I", "UP"]
4646
[tool.pytest.ini_options]
4747
addopts = "-ra"
4848
pythonpath = ["."]
49+
testpaths = ["tests"]
4950

5051
[build-system]
5152
requires = ["poetry-core"]

tests/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import logging
2+
3+
import pytest
4+
5+
6+
def pytest_addoption(parser: pytest.Parser) -> None:
7+
parser.addoption(
8+
"--skip-cleanup-cms",
9+
action="store_true",
10+
default=False,
11+
help="Skip cleanup for the CMS resources after completing the tests.",
12+
)
13+
14+
15+
@pytest.fixture(scope="module")
16+
def cleanup_cms(request: pytest.FixtureRequest) -> bool:
17+
return not request.config.getoption("--skip-cleanup-cms")
18+
19+
20+
@pytest.fixture(scope="session", autouse=True)
21+
def setup_logging() -> None:
22+
# Suppress logging from testcontainers
23+
for logger_name in logging.root.manager.loggerDict:
24+
if logger_name.startswith("testcontainers"):
25+
logging.getLogger(logger_name).setLevel(logging.WARNING)
26+
27+
parent_logger = logging.getLogger("cmg.tests")
28+
parent_logger.setLevel(logging.DEBUG)
29+
30+
handler = logging.StreamHandler()
31+
handler.setLevel(logging.DEBUG)
32+
handler.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))
33+
parent_logger.addHandler(handler)
34+
35+
# Configure child loggers
36+
logging.getLogger("cmg.tests.integration").setLevel(logging.INFO)
37+
logging.getLogger("cmg.tests.unit").setLevel(logging.INFO)

tests/integration/assets/cms.env

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
COMPOSE_PROJECT_NAME=cmg-test
2+
3+
MLFLOW_DB_USERNAME=admin
4+
MLFLOW_DB_PASSWORD=admin
5+
AWS_ACCESS_KEY_ID=admin
6+
AWS_SECRET_ACCESS_KEY=admin123
7+
8+
GRAFANA_ADMIN_USER=admin
9+
GRAFANA_ADMIN_PASSWORD=admin
10+
11+
GRAYLOG_PASSWORD_SECRET=admin
12+
GRAYLOG_ROOT_PASSWORD_SHA2=admin

tests/integration/test_api.py

Lines changed: 180 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,38 @@
1+
import json
2+
13
import pytest
4+
import requests
25
from fastapi.testclient import TestClient
3-
from testcontainers.minio import MinioContainer
4-
from testcontainers.postgres import PostgresContainer
5-
from testcontainers.rabbitmq import RabbitMqContainer
66

7+
from cogstack_model_gateway.common.config import Config, load_config
8+
from cogstack_model_gateway.common.object_store import ObjectStoreManager
9+
from cogstack_model_gateway.common.queue import QueueManager
10+
from cogstack_model_gateway.common.tasks import Status, TaskManager
711
from cogstack_model_gateway.gateway.main import app
812
from tests.integration.utils import (
9-
clone_cogstack_model_serve,
13+
TEST_MODEL_SERVICE,
1014
configure_environment,
11-
remove_cogstack_model_serve,
12-
remove_testcontainers,
13-
start_cogstack_model_serve,
14-
start_scheduler,
15-
start_testcontainers,
16-
stop_cogstack_model_serve,
17-
stop_scheduler,
15+
setup_cms,
16+
setup_scheduler,
17+
setup_testcontainers,
1818
)
1919

20-
POSTGRES_IMAGE = "postgres:17.2"
21-
RABBITMQ_IMAGE = "rabbitmq:4.0.4-management-alpine"
22-
MINIO_IMAGE = "minio/minio:RELEASE.2024-11-07T00-52-20Z"
23-
2420

2521
@pytest.fixture(scope="module", autouse=True)
26-
def setup(request):
27-
postgres = PostgresContainer(POSTGRES_IMAGE)
28-
rabbitmq = RabbitMqContainer(RABBITMQ_IMAGE)
29-
minio = MinioContainer(MINIO_IMAGE)
30-
31-
containers = [postgres, rabbitmq, minio]
32-
request.addfinalizer(lambda: remove_testcontainers(containers))
22+
def setup(request: pytest.FixtureRequest, cleanup_cms: bool):
23+
postgres, rabbitmq, minio = setup_testcontainers(request)
3324

34-
start_testcontainers(containers)
25+
svc_addr_map = setup_cms(request, cleanup_cms)
26+
request.config.cache.set("TEST_MODEL_SERVICE_IP", svc_addr_map[TEST_MODEL_SERVICE]["address"])
3527

36-
configure_environment(postgres, rabbitmq, minio)
28+
mlflow_addr = svc_addr_map["mlflow-ui"]["address"]
29+
mlflow_port = svc_addr_map["mlflow-ui"]["port"]
30+
env = {
31+
"MLFLOW_TRACKING_URI": f"http://{mlflow_addr}:{mlflow_port}",
32+
}
33+
configure_environment(postgres, rabbitmq, minio, extras=env)
3734

38-
scheduler_process = start_scheduler()
39-
request.addfinalizer(lambda: stop_scheduler(scheduler_process))
40-
41-
clone_cogstack_model_serve()
42-
request.addfinalizer(remove_cogstack_model_serve)
43-
44-
cms_compose_envs = start_cogstack_model_serve()
45-
request.addfinalizer(lambda: stop_cogstack_model_serve(cms_compose_envs))
35+
setup_scheduler(request)
4636

4737

4838
@pytest.fixture(scope="module")
@@ -51,7 +41,165 @@ def client():
5141
yield client
5242

5343

44+
@pytest.fixture(scope="module")
45+
def config(client: TestClient) -> Config:
46+
return load_config()
47+
48+
49+
@pytest.fixture(scope="module")
50+
def test_model_service_ip(request: pytest.FixtureRequest) -> str:
51+
return request.config.cache.get("TEST_MODEL_SERVICE_IP", None)
52+
53+
54+
def test_config_loaded(config: Config):
55+
assert config
56+
assert all(
57+
key in config
58+
for key in [
59+
"database_manager",
60+
"task_object_store_manager",
61+
"results_object_store_manager",
62+
"queue_manager",
63+
"task_manager",
64+
]
65+
)
66+
67+
5468
def test_root(client: TestClient):
5569
response = client.get("/")
5670
assert response.status_code == 200
5771
assert response.json() == {"message": "Enter the cult... I mean, the API."}
72+
73+
74+
def test_get_tasks(client: TestClient):
75+
response = client.get("/tasks/")
76+
assert response.status_code == 403
77+
assert response.json() == {"detail": "Only admins can list tasks"}
78+
79+
80+
def test_get_task_by_uuid(client: TestClient, config: Config):
81+
task_uuid = "nonexistent-uuid"
82+
response = client.get(f"/tasks/{task_uuid}")
83+
assert response.status_code == 404
84+
assert response.json() == {"detail": f"Task '{task_uuid}' not found"}
85+
86+
tm: TaskManager = config.task_manager
87+
task_uuid = tm.create_task(status="pending")
88+
response = client.get(f"/tasks/{task_uuid}")
89+
assert response.status_code == 200
90+
assert response.json() == {"uuid": task_uuid, "status": "pending"}
91+
92+
tm.update_task(task_uuid, status="succeeded", result="result.txt", error_message=None)
93+
response = client.get(f"/tasks/{task_uuid}", params={"detail": True})
94+
assert response.status_code == 200
95+
assert response.json() == {
96+
"uuid": task_uuid,
97+
"status": "succeeded",
98+
"result": "result.txt",
99+
"error_message": None,
100+
"tracking_id": None,
101+
}
102+
103+
104+
def test_get_models(client: TestClient):
105+
response = client.get("/models/")
106+
assert response.status_code == 200
107+
108+
response_json = response.json()
109+
assert isinstance(response_json, list)
110+
assert len(response_json) == 1
111+
assert all(key in response_json[0] for key in ["name", "uri"])
112+
assert response_json[0]["name"] == TEST_MODEL_SERVICE
113+
114+
115+
def test_get_model_info(client: TestClient, test_model_service_ip: str):
116+
response = client.get(f"/models/{test_model_service_ip}/info")
117+
assert response.status_code == 200
118+
assert all(
119+
key in response.json()
120+
for key in ["api_version", "model_type", "model_description", "model_card"]
121+
)
122+
123+
124+
def test_unsupported_task(client: TestClient, test_model_service_ip: str):
125+
response = client.post(
126+
f"/models/{test_model_service_ip}/unsupported-task",
127+
headers={"Content-Type": "dummy"},
128+
)
129+
assert response.status_code == 404
130+
assert "Task 'unsupported-task' not found. Supported tasks are:" in response.json()["detail"]
131+
132+
133+
def test_process(client: TestClient, config: Config, test_model_service_ip: str):
134+
response = client.post(
135+
f"/models/{test_model_service_ip}/process",
136+
data="Spinal stenosis",
137+
headers={"Content-Type": "text/plain"},
138+
)
139+
assert response.status_code == 200
140+
response_json = response.json()
141+
assert all(key in response_json for key in ["uuid", "status"])
142+
143+
task_uuid = response_json["uuid"]
144+
tm: TaskManager = config.task_manager
145+
assert tm.get_task(task_uuid), "Failed to submit task: not found in the database"
146+
147+
# Wait for the task to complete
148+
while (task := tm.get_task(task_uuid)).status != Status.SUCCEEDED:
149+
pass
150+
151+
# Verify that the task payload was stored in the object store
152+
task_payload_key = f"{task_uuid}_payload.txt"
153+
tom: ObjectStoreManager = config.task_object_store_manager
154+
payload = tom.get_object(task_payload_key)
155+
assert payload == b"Spinal stenosis"
156+
157+
# Verify that the queue is empty after the task is processed
158+
qm: QueueManager = config.queue_manager
159+
assert qm.is_queue_empty()
160+
161+
# Verify task results
162+
assert task.error_message is None, f"Task failed unexpectedly: {task.error_message}"
163+
assert task.result is not None, "Task results are missing"
164+
165+
rom: ObjectStoreManager = config.results_object_store_manager
166+
result = rom.get_object(task.result)
167+
168+
try:
169+
result_json = json.loads(result.decode("utf-8"))
170+
except json.JSONDecodeError as e:
171+
pytest.fail(f"Failed to parse the result as JSON: {result}, {e}")
172+
173+
assert result_json["text"] == "Spinal stenosis"
174+
assert len(result_json["annotations"]) == 1
175+
176+
annotation = result_json["annotations"][0]
177+
assert all(
178+
key in annotation
179+
for key in [
180+
"start",
181+
"end",
182+
"label_name",
183+
"label_id",
184+
"categories",
185+
"accuracy",
186+
"meta_anns",
187+
"athena_ids",
188+
]
189+
)
190+
assert annotation["label_name"] == "Spinal Stenosis"
191+
192+
# Verify that the above match the information exposed through the user-facing API
193+
get_response = client.get(f"/tasks/{task_uuid}", params={"detail": True, "download_url": True})
194+
assert get_response.status_code == 200
195+
196+
get_response_json = get_response.json()
197+
assert get_response_json["uuid"] == task.uuid
198+
assert get_response_json["status"] == task.status
199+
assert get_response_json["error_message"] is None
200+
assert get_response_json["tracking_id"] is None
201+
202+
# Download results and verify they match the ones read from the object store
203+
download_results = requests.get(get_response_json["result"])
204+
assert download_results.status_code == 200
205+
assert download_results.content == result

0 commit comments

Comments
 (0)