Skip to content

Commit 8277f29

Browse files
committed
refactor: streamline service ports retrieval tests and enhance permission checks
1 parent 9666345 commit 8277f29

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

services/catalog/tests/unit/with_dbs/test_api_rpc.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -584,25 +584,15 @@ async def test_rpc_get_service_ports_successful_retrieval(
584584
product_name: ProductName,
585585
user_id: UserID,
586586
app: FastAPI,
587-
create_fake_service_data: Callable,
588-
services_db_tables_injector: Callable,
587+
expected_director_rest_api_list_services: list[dict[str, Any]],
589588
):
590589
"""Tests successful retrieval of service ports for a specific service version"""
591590
assert app
592591

593592
# Create a service with known ports
594-
service_key = "simcore/services/comp/test-service-ports"
595-
service_version = "1.0.0"
596-
597-
# Create and inject the service
598-
fake_service = create_fake_service_data(
599-
service_key,
600-
service_version,
601-
team_access=None,
602-
everyone_access=None,
603-
product=product_name,
604-
)
605-
await services_db_tables_injector([fake_service])
593+
expected_service = expected_director_rest_api_list_services[0]
594+
service_key = expected_service["key"]
595+
service_version = expected_service["version"]
606596

607597
# Call the RPC function to get service ports
608598
ports = await catalog_rpc.get_service_ports(
@@ -614,12 +604,9 @@ async def test_rpc_get_service_ports_successful_retrieval(
614604
)
615605

616606
# Validate the response
617-
assert isinstance(ports, list)
618-
# Each port should have expected fields
619-
for port in ports:
620-
assert hasattr(port, "kind")
621-
assert hasattr(port, "key")
622-
assert hasattr(port, "port")
607+
expected_inputs = expected_service["inputs"]
608+
expected_outputs = expected_service["outputs"]
609+
assert len(ports) == len(expected_inputs) + len(expected_outputs)
623610

624611

625612
async def test_rpc_get_service_ports_not_found(
@@ -654,13 +641,17 @@ async def test_rpc_get_service_ports_permission_denied(
654641
product_name: ProductName,
655642
user: dict[str, Any],
656643
user_id: UserID,
644+
other_user: dict[str, Any],
657645
app: FastAPI,
658646
create_fake_service_data: Callable,
659647
services_db_tables_injector: Callable,
660648
):
661649
"""Tests that appropriate error is raised when user doesn't have permission"""
662650
assert app
663651

652+
assert other_user["id"] != user_id
653+
assert user["id"] == user_id
654+
664655
# Create a service with restricted access
665656
restricted_service_key = "simcore/services/comp/restricted-service"
666657
service_version = "1.0.0"
@@ -674,10 +665,12 @@ async def test_rpc_get_service_ports_permission_denied(
674665
)
675666

676667
# Modify access rights to restrict access
677-
if "access_rights" in fake_restricted_service:
678-
# Remove user's access if present
679-
if user["primary_gid"] in fake_restricted_service["access_rights"]:
680-
fake_restricted_service["access_rights"].pop(user["primary_gid"])
668+
# Remove user's access if present
669+
if (
670+
"access_rights" in fake_restricted_service
671+
and user["primary_gid"] in fake_restricted_service["access_rights"]
672+
):
673+
fake_restricted_service["access_rights"].pop(user["primary_gid"])
681674

682675
await services_db_tables_injector([fake_restricted_service])
683676

@@ -686,7 +679,7 @@ async def test_rpc_get_service_ports_permission_denied(
686679
await catalog_rpc.get_service_ports(
687680
rpc_client,
688681
product_name=product_name,
689-
user_id=UserID("different-user"), # Use a different user ID
682+
user_id=other_user["id"], # Use a different user ID
690683
service_key=restricted_service_key,
691684
service_version=service_version,
692685
)

0 commit comments

Comments
 (0)