Skip to content

Commit 75e4088

Browse files
committed
✨ Implement service type filtering in latest services queries
1 parent c4592ff commit 75e4088

File tree

3 files changed

+175
-12
lines changed

3 files changed

+175
-12
lines changed

services/catalog/src/simcore_service_catalog/repository/_services_sql.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sqlalchemy.sql.expression import func
2020
from sqlalchemy.sql.selectable import Select
2121

22-
from ..models.services_db import ServiceMetaDataDBGet
22+
from ..models.services_db import ServiceFiltersDB, ServiceMetaDataDBGet
2323

2424
SERVICES_META_DATA_COLS = get_columns_from_db_model(
2525
services_meta_data, ServiceMetaDataDBGet
@@ -113,13 +113,22 @@ def _has_access_rights(
113113
)
114114

115115

116+
def apply_services_filters(
117+
stmt,
118+
filters: ServiceFiltersDB,
119+
):
120+
if filters.service_type:
121+
stmt = stmt.where(services_meta_data.c.service_type == filters.service_type)
122+
123+
116124
def latest_services_total_count_stmt(
117125
*,
118126
product_name: ProductName,
119127
user_id: UserID,
120128
access_rights: sa.sql.ClauseElement,
129+
filters: ServiceFiltersDB | None = None,
121130
):
122-
return (
131+
stmt = (
123132
sa.select(func.count(sa.distinct(services_meta_data.c.key)))
124133
.select_from(
125134
services_meta_data.join(
@@ -136,6 +145,11 @@ def latest_services_total_count_stmt(
136145
.where(access_rights)
137146
)
138147

148+
if filters:
149+
apply_services_filters(stmt, filters)
150+
151+
return stmt
152+
139153

140154
def list_latest_services_stmt(
141155
*,
@@ -144,10 +158,11 @@ def list_latest_services_stmt(
144158
access_rights: sa.sql.ClauseElement,
145159
limit: int | None,
146160
offset: int | None,
161+
filters: ServiceFiltersDB | None = None,
147162
):
148163
# get all distinct services key fitting a page
149164
# and its corresponding latest version
150-
cte = (
165+
cte_stmt = (
151166
sa.select(
152167
services_meta_data.c.key,
153168
services_meta_data.c.version.label("latest_version"),
@@ -172,9 +187,13 @@ def list_latest_services_stmt(
172187
.distinct(services_meta_data.c.key) # get only first
173188
.limit(limit)
174189
.offset(offset)
175-
.cte("cte")
176190
)
177191

192+
if filters:
193+
apply_services_filters(cte_stmt, filters)
194+
195+
cte = cte_stmt.cte("cte")
196+
178197
# get all information of latest's services listed in CTE
179198
latest_stmt = (
180199
sa.select(

services/catalog/src/simcore_service_catalog/repository/services.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from simcore_postgres_database.models.services_specifications import (
2929
services_specifications,
3030
)
31-
from simcore_postgres_database.utils import as_postgres_sql_query_str
3231
from simcore_postgres_database.utils_repos import pass_or_acquire_connection
3332
from simcore_postgres_database.utils_services import create_select_latest_services_query
3433
from sqlalchemy import sql
@@ -37,6 +36,7 @@
3736
from ..models.services_db import (
3837
ReleaseDBGet,
3938
ServiceAccessRightsAtDB,
39+
ServiceFiltersDB,
4040
ServiceMetaDataDBCreate,
4141
ServiceMetaDataDBGet,
4242
ServiceMetaDataDBPatch,
@@ -47,6 +47,7 @@
4747
from ._services_sql import (
4848
SERVICES_META_DATA_COLS,
4949
AccessRightsClauses,
50+
apply_services_filters,
5051
by_version,
5152
can_get_service_stmt,
5253
get_service_history_stmt,
@@ -391,20 +392,23 @@ async def list_latest_services(
391392
# list args: pagination
392393
limit: int | None = None,
393394
offset: int | None = None,
395+
filters: ServiceFiltersDB | None = None,
394396
) -> tuple[PositiveInt, list[ServiceWithHistoryDBGet]]:
395397

396398
# get page
397399
stmt_total = latest_services_total_count_stmt(
398400
product_name=product_name,
399401
user_id=user_id,
400402
access_rights=AccessRightsClauses.can_read,
403+
filters=filters,
401404
)
402405
stmt_page = list_latest_services_stmt(
403406
product_name=product_name,
404407
user_id=user_id,
405408
access_rights=AccessRightsClauses.can_read,
406409
limit=limit,
407410
offset=offset,
411+
filters=filters,
408412
)
409413

410414
async with self.db_engine.connect() as conn:
@@ -480,9 +484,10 @@ async def get_service_history_page(
480484
# list args: pagination
481485
limit: int | None = None,
482486
offset: int | None = None,
487+
filters: ServiceFiltersDB | None = None,
483488
) -> tuple[PositiveInt, list[ReleaseDBGet]]:
484489

485-
base_subquery = (
490+
base_stmt = (
486491
# Search on service (key, *) for (product_name, user_id w/ access)
487492
sql.select(
488493
services_meta_data.c.key,
@@ -506,11 +511,15 @@ async def get_service_history_page(
506511
& (user_to_groups.c.uid == user_id)
507512
& AccessRightsClauses.can_read
508513
)
509-
).subquery()
514+
)
515+
516+
if filters:
517+
apply_services_filters(base_stmt, filters)
518+
519+
base_subquery = base_stmt.subquery()
510520

511521
# Query to count the TOTAL number of rows
512522
count_query = sql.select(sql.func.count()).select_from(base_subquery)
513-
_logger.debug("count_query=\n%s", as_postgres_sql_query_str(count_query))
514523

515524
# Query to retrieve page with additional columns, ordering, offset, and limit
516525
page_query = (
@@ -541,7 +550,6 @@ async def get_service_history_page(
541550
.offset(offset)
542551
.limit(limit)
543552
)
544-
_logger.debug("page_query=\n%s", as_postgres_sql_query_str(page_query))
545553

546554
async with pass_or_acquire_connection(self.db_engine) as conn:
547555
total_count: PositiveInt = await conn.scalar(count_query) or 0

services/catalog/tests/unit/with_dbs/test_repositories.py

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import pytest
1515
from models_library.products import ProductName
16+
from models_library.services_enums import ServiceType # Import ServiceType enum
1617
from models_library.users import UserID
1718
from packaging import version
1819
from pydantic import EmailStr, HttpUrl, TypeAdapter
@@ -21,6 +22,7 @@
2122
from simcore_postgres_database.models.projects import ProjectType, projects
2223
from simcore_service_catalog.models.services_db import (
2324
ServiceAccessRightsAtDB,
25+
ServiceFiltersDB,
2426
ServiceMetaDataDBCreate,
2527
ServiceMetaDataDBGet,
2628
ServiceMetaDataDBPatch,
@@ -310,7 +312,7 @@ async def test_get_latest_release(
310312
assert latest.version == fake_catalog_with_jupyterlab.expected_latest
311313

312314

313-
async def test_list_all_services_and_history(
315+
async def test_list_latest_services(
314316
target_product: ProductName,
315317
user_id: UserID,
316318
services_repo: ServicesRepository,
@@ -332,7 +334,7 @@ async def test_list_all_services_and_history(
332334
), "list_latest_service does NOT show history"
333335

334336

335-
async def test_listing_with_no_services(
337+
async def test_list_latest_services_with_no_services(
336338
target_product: ProductName,
337339
services_repo: ServicesRepository,
338340
user_id: UserID,
@@ -344,7 +346,7 @@ async def test_listing_with_no_services(
344346
assert total_count == 0
345347

346348

347-
async def test_list_all_services_and_history_with_pagination(
349+
async def test_list_latest_services_with_pagination(
348350
target_product: ProductName,
349351
create_fake_service_data: Callable,
350352
services_db_tables_injector: Callable,
@@ -403,6 +405,64 @@ async def test_list_all_services_and_history_with_pagination(
403405
), f"list of latest versions of services cannot have duplicates, found: {duplicates}"
404406

405407

408+
async def test_list_latest_services_with_filters(
409+
target_product: ProductName,
410+
create_fake_service_data: Callable,
411+
services_db_tables_injector: Callable,
412+
services_repo: ServicesRepository,
413+
user_id: UserID,
414+
):
415+
# Setup: Inject services with different service types
416+
await services_db_tables_injector(
417+
[
418+
create_fake_service_data(
419+
f"simcore/services/dynamic/service-type-a-{i}",
420+
"1.0.0",
421+
team_access=None,
422+
everyone_access=None,
423+
product=target_product,
424+
service_type=ServiceType.DYNAMIC.value,
425+
)
426+
for i in range(3)
427+
]
428+
+ [
429+
create_fake_service_data(
430+
f"simcore/services/dynamic/service-type-b-{i}",
431+
"1.0.0",
432+
team_access=None,
433+
everyone_access=None,
434+
product=target_product,
435+
service_type=ServiceType.COMPUTATIONAL.value,
436+
)
437+
for i in range(2)
438+
]
439+
)
440+
441+
# Test: Apply filter for service_type=ServiceType.DYNAMIC
442+
filters = ServiceFiltersDB(service_type=ServiceType.DYNAMIC)
443+
total_count, services_items = await services_repo.list_latest_services(
444+
product_name=target_product, user_id=user_id, filters=filters
445+
)
446+
assert total_count == 3
447+
assert len(services_items) == 3
448+
assert all(
449+
service.key.startswith("simcore/services/dynamic/service-type-a")
450+
for service in services_items
451+
)
452+
453+
# Test: Apply filter for service_type=ServiceType.COMPUTATIONAL
454+
filters = ServiceFiltersDB(service_type=ServiceType.COMPUTATIONAL)
455+
total_count, services_items = await services_repo.list_latest_services(
456+
product_name=target_product, user_id=user_id, filters=filters
457+
)
458+
assert total_count == 2
459+
assert len(services_items) == 2
460+
assert all(
461+
service.key.startswith("simcore/services/dynamic/service-type-b")
462+
for service in services_items
463+
)
464+
465+
406466
async def test_get_and_update_service_meta_data(
407467
target_product: ProductName,
408468
create_fake_service_data: Callable,
@@ -566,6 +626,82 @@ async def test_get_service_history_page(
566626
assert paginated_history == history[offset : offset + limit]
567627

568628

629+
async def test_get_service_history_page_with_filters(
630+
target_product: ProductName,
631+
create_fake_service_data: Callable,
632+
services_db_tables_injector: Callable,
633+
services_repo: ServicesRepository,
634+
user_id: UserID,
635+
):
636+
# Setup: Inject services with multiple versions and types
637+
service_key = "simcore/services/dynamic/test-service"
638+
num_versions = 10
639+
640+
release_versions = set()
641+
while len(release_versions) < num_versions:
642+
release_versions.add(
643+
f"{random.randint(0, 2)}.{random.randint(0, 9)}.{random.randint(0, 9)}" # noqa: S311
644+
)
645+
646+
await services_db_tables_injector(
647+
[
648+
create_fake_service_data(
649+
service_key,
650+
service_version,
651+
team_access=None,
652+
everyone_access=None,
653+
product=target_product,
654+
service_type=(
655+
ServiceType.DYNAMIC.value
656+
if i % 2 == 0
657+
else ServiceType.COMPUTATIONAL.value
658+
),
659+
)
660+
for i, service_version in enumerate(release_versions)
661+
]
662+
)
663+
# Sort versions after injecting
664+
release_versions = sorted(release_versions, key=version.Version, reverse=True)
665+
666+
# Test: Fetch full history with no filters
667+
total_count, history = await services_repo.get_service_history_page(
668+
product_name=target_product,
669+
user_id=user_id,
670+
key=service_key,
671+
)
672+
assert total_count == num_versions
673+
assert len(history) == num_versions
674+
assert [release.version for release in history] == release_versions
675+
676+
# Test: Apply filter for service_type=ServiceType.DYNAMIC
677+
filters = ServiceFiltersDB(service_type=ServiceType.DYNAMIC)
678+
total_count, filtered_history = await services_repo.get_service_history_page(
679+
product_name=target_product,
680+
user_id=user_id,
681+
key=service_key,
682+
filters=filters,
683+
)
684+
assert total_count == num_versions // 2
685+
assert len(filtered_history) == num_versions // 2
686+
assert all(
687+
int(release.version.split(".")[0]) % 2 == 0 for release in filtered_history
688+
)
689+
690+
# Test: Apply filter for service_type=ServiceType.COMPUTATIONAL
691+
filters = ServiceFiltersDB(service_type=ServiceType.COMPUTATIONAL)
692+
total_count, filtered_history = await services_repo.get_service_history_page(
693+
product_name=target_product,
694+
user_id=user_id,
695+
key=service_key,
696+
filters=filters,
697+
)
698+
assert total_count == num_versions // 2
699+
assert len(filtered_history) == num_versions // 2
700+
assert all(
701+
int(release.version.split(".")[0]) % 2 != 0 for release in filtered_history
702+
)
703+
704+
569705
async def test_list_services_from_published_templates(
570706
user: dict[str, Any],
571707
projects_repo: ProjectsRepository,

0 commit comments

Comments
 (0)