Skip to content

Commit 0d00225

Browse files
committed
♻️ Refactor: migrate to asyncpg and improve database connection handling in studies dispatcher
1 parent 86003e4 commit 0d00225

File tree

4 files changed

+163
-24
lines changed

4 files changed

+163
-24
lines changed

services/web/server/src/simcore_service_webserver/studies_dispatcher/_catalog.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from dataclasses import dataclass
55

66
import sqlalchemy as sa
7-
from aiohttp import web
8-
from aiopg.sa.connection import SAConnection
9-
from aiopg.sa.engine import Engine
107
from models_library.groups import EVERYONE_GROUP_ID
118
from models_library.services import ServiceKey, ServiceVersion
129
from models_library.services_constants import (
@@ -22,11 +19,12 @@
2219
from simcore_postgres_database.models.services_consume_filetypes import (
2320
services_consume_filetypes,
2421
)
22+
from simcore_postgres_database.utils_repos import pass_or_acquire_connection
2523
from simcore_postgres_database.utils_services import create_select_latest_services_query
24+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
2625

27-
from ..db.plugin import get_database_engine_legacy
2826
from ._errors import ServiceNotFoundError
29-
from .settings import StudiesDispatcherSettings, get_plugin_settings
27+
from .settings import StudiesDispatcherSettings
3028

3129
LARGEST_PAGE_SIZE = 1000
3230

@@ -44,22 +42,28 @@ class ServiceMetaData:
4442
file_extensions: list[str]
4543

4644

47-
async def _get_service_filetypes(conn: SAConnection) -> dict[ServiceKey, list[str]]:
45+
async def _get_service_filetypes(
46+
engine: AsyncEngine,
47+
connection: AsyncConnection | None = None,
48+
) -> dict[ServiceKey, list[str]]:
4849
query = sa.select(
4950
services_consume_filetypes.c.service_key,
5051
sa.func.array_agg(
5152
sa.func.distinct(services_consume_filetypes.c.filetype)
5253
).label("list_of_file_types"),
5354
).group_by(services_consume_filetypes.c.service_key)
5455

55-
result = await conn.execute(query)
56-
rows = await result.fetchall()
56+
async with pass_or_acquire_connection(engine, connection) as conn:
57+
result = await conn.execute(query)
58+
rows = result.fetchall()
5759

58-
return {row.service_key: row.list_of_file_types for row in rows}
60+
return {row.service_key: row.list_of_file_types for row in rows}
5961

6062

6163
async def iter_latest_product_services(
62-
app: web.Application,
64+
settings: StudiesDispatcherSettings,
65+
engine: AsyncEngine,
66+
connection: AsyncConnection | None = None,
6367
*,
6468
product_name: str,
6569
page_number: PositiveInt = 1, # 1-based
@@ -68,9 +72,6 @@ async def iter_latest_product_services(
6872
assert page_number >= 1 # nosec
6973
assert ((page_number - 1) * page_size) >= 0 # nosec
7074

71-
engine: Engine = get_database_engine_legacy(app)
72-
settings: StudiesDispatcherSettings = get_plugin_settings(app)
73-
7475
# Select query for latest version of the service
7576
latest_services = create_select_latest_services_query().alias("latest_services")
7677

@@ -109,10 +110,10 @@ async def iter_latest_product_services(
109110
# pagination
110111
query = query.limit(page_size).offset((page_number - 1) * page_size)
111112

112-
async with engine.acquire() as conn:
113-
service_filetypes = await _get_service_filetypes(conn)
113+
async with pass_or_acquire_connection(engine, connection) as conn:
114+
service_filetypes = await _get_service_filetypes(engine, conn)
114115

115-
async for row in await conn.execute(query):
116+
async for row in await conn.stream(query):
116117
yield ServiceMetaData(
117118
key=row.key,
118119
version=row.version,
@@ -135,14 +136,13 @@ class ValidService:
135136

136137
@log_decorator(_logger, level=logging.DEBUG)
137138
async def validate_requested_service(
138-
app: web.Application,
139+
engine: AsyncEngine,
140+
connection: AsyncConnection | None = None,
139141
*,
140142
service_key: ServiceKey,
141143
service_version: ServiceVersion,
142144
) -> ValidService:
143-
engine: Engine = get_database_engine_legacy(app)
144-
145-
async with engine.acquire() as conn:
145+
async with pass_or_acquire_connection(engine, connection) as conn:
146146
query = sa.select(
147147
services_meta_data.c.name,
148148
services_meta_data.c.key,
@@ -153,7 +153,7 @@ async def validate_requested_service(
153153
)
154154

155155
result = await conn.execute(query)
156-
row = await result.fetchone()
156+
row = result.one_or_none()
157157

158158
if row is None:
159159
raise ServiceNotFoundError(

services/web/server/src/simcore_service_webserver/studies_dispatcher/_controller/rest/nih.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
)
1010

1111
from ...._meta import API_VTAG
12+
from ....db.plugin import get_asyncpg_engine
1213
from ....products import products_web
1314
from ....utils_aiohttp import envelope_json_response
1415
from ... import _service
1516
from ..._catalog import iter_latest_product_services
17+
from ...settings import get_plugin_settings
1618
from .nih_schemas import ServiceGet, Viewer
1719

1820
_logger = logging.getLogger(__name__)
@@ -26,9 +28,12 @@ async def list_latest_services(request: Request):
2628
"""Returns a list latest version of services"""
2729
product_name = products_web.get_product_name(request)
2830

31+
plugin_settings = get_plugin_settings(request.app)
32+
engine = get_asyncpg_engine(request.app)
33+
2934
services = []
3035
async for service_data in iter_latest_product_services(
31-
request.app, product_name=product_name
36+
plugin_settings, engine, product_name=product_name
3237
):
3338
try:
3439
service = ServiceGet.create(service_data, request)

services/web/server/src/simcore_service_webserver/studies_dispatcher/_controller/rest/redirects.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from models_library.projects_nodes_io import NodeID
88
from servicelib.aiohttp.requests_validation import parse_request_query_parameters_as
99

10+
from ....db.plugin import get_asyncpg_engine
1011
from ....dynamic_scheduler import api as dynamic_scheduler_service
1112
from ....products import products_web
1213
from ....utils_aiohttp import create_redirect_to_page_response, get_api_base_url
@@ -133,7 +134,7 @@ async def get_redirection_to_viewer(request: web.Request):
133134
service_params_ = query_params
134135

135136
valid_service: ValidService = await validate_requested_service(
136-
app=request.app,
137+
get_asyncpg_engine(request.app),
137138
service_key=service_params_.viewer_key,
138139
service_version=service_params_.viewer_version,
139140
)

services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_repository.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,35 @@
77
from collections.abc import AsyncIterator
88

99
import pytest
10+
from models_library.groups import EVERYONE_GROUP_ID
11+
from models_library.services import ServiceKey, ServiceVersion
1012
from pytest_simcore.helpers.faker_factories import (
13+
random_service_access_rights,
1114
random_service_consume_filetype,
1215
random_service_meta_data,
1316
)
1417
from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan
15-
from simcore_postgres_database.models.services import services_meta_data
18+
from simcore_postgres_database.models.services import (
19+
services_access_rights,
20+
services_meta_data,
21+
)
1622
from simcore_postgres_database.models.services_consume_filetypes import (
1723
services_consume_filetypes,
1824
)
25+
from simcore_service_webserver.studies_dispatcher._catalog import (
26+
ServiceMetaData,
27+
ValidService,
28+
iter_latest_product_services,
29+
validate_requested_service,
30+
)
31+
from simcore_service_webserver.studies_dispatcher._errors import ServiceNotFoundError
1932
from simcore_service_webserver.studies_dispatcher._models import ViewerInfo
2033
from simcore_service_webserver.studies_dispatcher._repository import (
2134
StudiesDispatcherRepository,
2235
)
36+
from simcore_service_webserver.studies_dispatcher.settings import (
37+
StudiesDispatcherSettings,
38+
)
2339
from sqlalchemy.ext.asyncio import AsyncEngine
2440

2541

@@ -69,6 +85,37 @@ async def consume_filetypes_in_db(
6985
yield row
7086

7187

88+
@pytest.fixture
89+
async def service_access_rights_in_db(
90+
asyncpg_engine: AsyncEngine, service_metadata_in_db: dict
91+
):
92+
"""Pre-populate services access rights table with test data."""
93+
access_data = random_service_access_rights(
94+
key=service_metadata_in_db["key"],
95+
version=service_metadata_in_db["version"],
96+
gid=EVERYONE_GROUP_ID,
97+
execute_access=True,
98+
product_name="osparc",
99+
)
100+
101+
# pylint: disable=contextmanager-generator-missing-cleanup
102+
async with insert_and_get_row_lifespan(
103+
asyncpg_engine,
104+
table=services_access_rights,
105+
values=access_data,
106+
pk_col=services_access_rights.c.key,
107+
pk_value=access_data["key"],
108+
) as row:
109+
yield row
110+
111+
112+
@pytest.fixture
113+
def studies_dispatcher_settings() -> StudiesDispatcherSettings:
114+
return StudiesDispatcherSettings(
115+
STUDIES_DEFAULT_SERVICE_THUMBNAIL="https://example.com/default-thumbnail.png"
116+
)
117+
118+
72119
@pytest.fixture
73120
def studies_dispatcher_repository(
74121
asyncpg_engine: AsyncEngine,
@@ -201,3 +248,89 @@ async def test_find_compatible_viewer_not_found(
201248

202249
# Assert
203250
assert viewer_wrong_filetype is None
251+
252+
253+
async def test_iter_latest_product_services(
254+
asyncpg_engine: AsyncEngine,
255+
studies_dispatcher_settings: StudiesDispatcherSettings,
256+
service_metadata_in_db: dict,
257+
service_access_rights_in_db: dict,
258+
consume_filetypes_in_db: dict,
259+
):
260+
"""Test iterating through latest product services."""
261+
# Act
262+
services = []
263+
async for service in iter_latest_product_services(
264+
studies_dispatcher_settings, asyncpg_engine, product_name="osparc"
265+
):
266+
services.append(service)
267+
268+
# Assert
269+
assert len(services) == 1
270+
service = services[0]
271+
assert isinstance(service, ServiceMetaData)
272+
assert service.key == service_metadata_in_db["key"]
273+
assert service.version == service_metadata_in_db["version"]
274+
assert service.title == service_metadata_in_db["name"]
275+
assert service.description == service_metadata_in_db["description"]
276+
assert service.file_extensions == [consume_filetypes_in_db["filetype"]]
277+
278+
279+
async def test_iter_latest_product_services_with_pagination(
280+
asyncpg_engine: AsyncEngine,
281+
studies_dispatcher_settings: StudiesDispatcherSettings,
282+
service_metadata_in_db: dict,
283+
service_access_rights_in_db: dict,
284+
):
285+
"""Test iterating through services with pagination."""
286+
# Act
287+
services = []
288+
async for service in iter_latest_product_services(
289+
studies_dispatcher_settings,
290+
asyncpg_engine,
291+
product_name="osparc",
292+
page_number=1,
293+
page_size=1,
294+
):
295+
services.append(service)
296+
297+
# Assert
298+
assert len(services) == 1
299+
300+
301+
async def test_validate_requested_service_success(
302+
asyncpg_engine: AsyncEngine,
303+
service_metadata_in_db: dict,
304+
consume_filetypes_in_db: dict,
305+
):
306+
"""Test validating a service that exists and is valid."""
307+
# Act
308+
valid_service = await validate_requested_service(
309+
engine=asyncpg_engine,
310+
service_key=ServiceKey(service_metadata_in_db["key"]),
311+
service_version=ServiceVersion(service_metadata_in_db["version"]),
312+
)
313+
314+
# Assert
315+
assert isinstance(valid_service, ValidService)
316+
assert valid_service.key == service_metadata_in_db["key"]
317+
assert valid_service.version == service_metadata_in_db["version"]
318+
assert valid_service.title == service_metadata_in_db["name"]
319+
assert valid_service.is_public == consume_filetypes_in_db["is_guest_allowed"]
320+
assert valid_service.thumbnail is None # No valid URL in test data
321+
322+
323+
async def test_validate_requested_service_not_found(
324+
asyncpg_engine: AsyncEngine,
325+
):
326+
"""Test validating a service that doesn't exist."""
327+
# Act & Assert
328+
with pytest.raises(ServiceNotFoundError) as exc_info:
329+
await validate_requested_service(
330+
asyncpg_engine,
331+
service_key=ServiceKey("simcore/services/dynamic/nonexistent"),
332+
service_version=ServiceVersion("1.0.0"),
333+
)
334+
335+
assert exc_info.value.service_key == "simcore/services/dynamic/nonexistent"
336+
assert exc_info.value.service_version == "1.0.0"

0 commit comments

Comments
 (0)