diff --git a/packages/models-library/src/models_library/functions.py b/packages/models-library/src/models_library/functions.py index 3e85ad84017c..4ab6fe389b40 100644 --- a/packages/models-library/src/models_library/functions.py +++ b/packages/models-library/src/models_library/functions.py @@ -227,6 +227,8 @@ class FunctionJobDB(BaseModel): class_specific_data: FunctionJobClassSpecificData function_class: FunctionClass + model_config = ConfigDict(from_attributes=True) + class RegisteredFunctionJobDB(FunctionJobDB): uuid: FunctionJobID @@ -242,6 +244,8 @@ class FunctionDB(BaseModel): default_inputs: FunctionInputs class_specific_data: FunctionClassSpecificData + model_config = ConfigDict(from_attributes=True) + class RegisteredFunctionDB(FunctionDB): uuid: FunctionID @@ -252,6 +256,8 @@ class FunctionJobCollectionDB(BaseModel): title: str = "" description: str = "" + model_config = ConfigDict(from_attributes=True) + class RegisteredFunctionJobCollectionDB(FunctionJobCollectionDB): uuid: FunctionJobCollectionID diff --git a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py index d51f0df87878..f77b07011150 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py +++ b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py @@ -22,6 +22,7 @@ FunctionJobCollectionReadAccessDeniedError, FunctionJobCollectionsReadApiAccessDeniedError, FunctionJobCollectionsWriteApiAccessDeniedError, + FunctionJobCollectionWriteAccessDeniedError, FunctionJobIDNotFoundError, FunctionJobReadAccessDeniedError, FunctionJobsReadApiAccessDeniedError, @@ -283,6 +284,7 @@ async def delete_function_job( reraise_if_error_type=( FunctionJobCollectionIDNotFoundError, FunctionJobCollectionReadAccessDeniedError, + FunctionJobCollectionWriteAccessDeniedError, FunctionJobCollectionsWriteApiAccessDeniedError, ) ) @@ -305,6 +307,7 @@ async def delete_function_job_collection( reraise_if_error_type=( FunctionIDNotFoundError, FunctionReadAccessDeniedError, + FunctionWriteAccessDeniedError, ) ) async def update_function_title( @@ -325,7 +328,11 @@ async def update_function_title( @router.expose( - reraise_if_error_type=(FunctionIDNotFoundError, FunctionReadAccessDeniedError) + reraise_if_error_type=( + FunctionIDNotFoundError, + FunctionReadAccessDeniedError, + FunctionWriteAccessDeniedError, + ) ) async def update_function_description( app: web.Application, diff --git a/services/web/server/src/simcore_service_webserver/functions/_functions_repository.py b/services/web/server/src/simcore_service_webserver/functions/_functions_repository.py index 607a75ae6251..8cabb0185a67 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_functions_repository.py +++ b/services/web/server/src/simcore_service_webserver/functions/_functions_repository.py @@ -78,9 +78,11 @@ from simcore_postgres_database.models.funcapi_functions_table import functions_table from simcore_postgres_database.utils_repos import ( get_columns_from_db_model, + pass_or_acquire_connection, transaction_context, ) from sqlalchemy import Text, cast +from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.sql import func @@ -120,17 +122,16 @@ async def create_function( # noqa: PLR0913 output_schema: FunctionOutputSchema, default_inputs: FunctionInputs, ) -> RegisteredFunctionDB: + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.WRITE_FUNCTIONS], + ) - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.WRITE_FUNCTIONS], - ) - - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + result = await transaction.execute( functions_table.insert() .values( title=title, @@ -143,24 +144,24 @@ async def create_function( # noqa: PLR0913 ) .returning(*_FUNCTIONS_TABLE_COLS) ) - row = await result.one() + row = result.one() - registered_function = RegisteredFunctionDB.model_validate(dict(row)) + registered_function = RegisteredFunctionDB.model_validate(row) - user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) - await set_group_permissions( - app, - connection=connection, - group_id=user_primary_group_id, - product_name=product_name, - object_type="function", - object_ids=[registered_function.uuid], - read=True, - write=True, - execute=True, - ) + user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) + await set_group_permissions( + app, + connection=transaction, + group_id=user_primary_group_id, + product_name=product_name, + object_type="function", + object_ids=[registered_function.uuid], + read=True, + write=True, + execute=True, + ) - return RegisteredFunctionDB.model_validate(dict(row)) + return RegisteredFunctionDB.model_validate(row) async def create_function_job( # noqa: PLR0913 @@ -177,18 +178,17 @@ async def create_function_job( # noqa: PLR0913 outputs: FunctionOutputs, class_specific_data: FunctionJobClassSpecificData, ) -> RegisteredFunctionJobDB: - - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, - ], - ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, + ], + ) + result = await transaction.execute( function_jobs_table.insert() .values( function_uuid=function_uid, @@ -202,22 +202,22 @@ async def create_function_job( # noqa: PLR0913 ) .returning(*_FUNCTION_JOBS_TABLE_COLS) ) - row = await result.one() + row = result.one() - registered_function_job = RegisteredFunctionJobDB.model_validate(dict(row)) + registered_function_job = RegisteredFunctionJobDB.model_validate(row) - user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) - await set_group_permissions( - app, - connection=connection, - group_id=user_primary_group_id, - product_name=product_name, - object_type="function_job", - object_ids=[registered_function_job.uuid], - read=True, - write=True, - execute=True, - ) + user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) + await set_group_permissions( + app, + connection=transaction, + group_id=user_primary_group_id, + product_name=product_name, + object_type="function_job", + object_ids=[registered_function_job.uuid], + read=True, + write=True, + execute=True, + ) return registered_function_job @@ -232,29 +232,28 @@ async def create_function_job_collection( description: str, job_ids: list[FunctionJobID], ) -> tuple[RegisteredFunctionJobCollectionDB, list[FunctionJobID]]: - - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.WRITE_FUNCTION_JOB_COLLECTIONS, - ], - ) - for job_id in job_ids: - await check_user_permissions( + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( app, - connection=connection, + connection=transaction, user_id=user_id, product_name=product_name, - object_type="function_job", - object_id=job_id, - permissions=["read"], + api_access_rights=[ + FunctionsApiAccessRights.WRITE_FUNCTION_JOB_COLLECTIONS, + ], ) + for job_id in job_ids: + await check_user_permissions( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + object_type="function_job", + object_id=job_id, + permissions=["read"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + result = await transaction.execute( function_job_collections_table.insert() .values( title=title, @@ -262,7 +261,7 @@ async def create_function_job_collection( ) .returning(*_FUNCTION_JOB_COLLECTIONS_TABLE_COLS) ) - row = await result.one_or_none() + row = result.one_or_none() assert row is not None, ( "No row was returned from the database after creating function job collection." @@ -270,11 +269,11 @@ async def create_function_job_collection( ) # nosec function_job_collection_db = RegisteredFunctionJobCollectionDB.model_validate( - dict(row) + row ) - job_collection_entries = [] + job_collection_entries: list[Row] = [] for job_id in job_ids: - result = await conn.stream( + result = await transaction.execute( function_job_collections_to_function_jobs_table.insert() .values( function_job_collection_uuid=function_job_collection_db.uuid, @@ -285,28 +284,28 @@ async def create_function_job_collection( function_job_collections_to_function_jobs_table.c.function_job_uuid, ) ) - entry = await result.one_or_none() + entry = result.one_or_none() assert entry is not None, ( f"No row was returned from the database after creating function job collection entry {title}." f" Job ID: {job_id}" ) # nosec - job_collection_entries.append(dict(entry)) + job_collection_entries.append(entry) - user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) - await set_group_permissions( - app, - connection=connection, - group_id=user_primary_group_id, - product_name=product_name, - object_type="function_job_collection", - object_ids=[function_job_collection_db.uuid], - read=True, - write=True, - execute=True, - ) + user_primary_group_id = await get_user_primary_group_id(app, user_id=user_id) + await set_group_permissions( + app, + connection=transaction, + group_id=user_primary_group_id, + product_name=product_name, + object_type="function_job_collection", + object_ids=[function_job_collection_db.uuid], + read=True, + write=True, + execute=True, + ) return function_job_collection_db, [ - dict(entry)["function_job_uuid"] for entry in job_collection_entries + entry.function_job_uuid for entry in job_collection_entries ] @@ -318,34 +317,33 @@ async def get_function( product_name: ProductName, function_id: FunctionID, ) -> RegisteredFunctionDB: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTIONS], + ) - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTIONS], - ) - - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + result = await conn.execute( functions_table.select().where(functions_table.c.uuid == function_id) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionIDNotFoundError(function_id=function_id) - registered_function = RegisteredFunctionDB.model_validate(dict(row)) + registered_function = RegisteredFunctionDB.model_validate(row) - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_id, - object_type="function", - permissions=["read"], - ) + await check_user_permissions( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + object_id=function_id, + object_type="function", + permissions=["read"], + ) return registered_function @@ -359,15 +357,14 @@ async def list_functions( pagination_limit: int, pagination_offset: int, ) -> tuple[list[RegisteredFunctionDB], PageMetaInfoLimitOffset]: - - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTIONS], - ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTIONS], + ) user_groups = await list_all_user_groups_ids(app, user_id=user_id) subquery = ( @@ -386,25 +383,25 @@ async def list_functions( .select_from(functions_table) .where(functions_table.c.uuid.in_(subquery)) ) - result = await conn.stream( - functions_table.select() - .where(functions_table.c.uuid.in_(subquery)) - .offset(pagination_offset) - .limit(pagination_limit) - ) - rows = await result.all() - if rows is None: + if total_count_result == 0: return [], PageMetaInfoLimitOffset( total=0, offset=pagination_offset, limit=pagination_limit, count=0 ) + function_rows = [ + RegisteredFunctionDB.model_validate(row) + async for row in await conn.stream( + functions_table.select() + .where(functions_table.c.uuid.in_(subquery)) + .offset(pagination_offset) + .limit(pagination_limit) + ) + ] - return [ - RegisteredFunctionDB.model_validate(dict(row)) for row in rows - ], PageMetaInfoLimitOffset( + return function_rows, PageMetaInfoLimitOffset( total=total_count_result, offset=pagination_offset, limit=pagination_limit, - count=len(rows), + count=len(function_rows), ) @@ -418,15 +415,14 @@ async def list_function_jobs( pagination_offset: int, filter_by_function_id: FunctionID | None = None, ) -> tuple[list[RegisteredFunctionJobDB], PageMetaInfoLimitOffset]: - - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], - ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], + ) user_groups = await list_all_user_groups_ids(app, user_id=user_id) access_subquery = ( @@ -450,30 +446,30 @@ async def list_function_jobs( else sqlalchemy.sql.true() ) ) - result = await conn.stream( - function_jobs_table.select() - .where(function_jobs_table.c.uuid.in_(access_subquery)) - .where( - function_jobs_table.c.function_uuid == filter_by_function_id - if filter_by_function_id - else sqlalchemy.sql.true() - ) - .offset(pagination_offset) - .limit(pagination_limit) - ) - rows = await result.all() - if rows is None: + if total_count_result == 0: return [], PageMetaInfoLimitOffset( total=0, offset=pagination_offset, limit=pagination_limit, count=0 ) + results = [ + RegisteredFunctionJobDB.model_validate(row) + async for row in await conn.stream( + function_jobs_table.select() + .where(function_jobs_table.c.uuid.in_(access_subquery)) + .where( + function_jobs_table.c.function_uuid == filter_by_function_id + if filter_by_function_id + else sqlalchemy.sql.true() + ) + .offset(pagination_offset) + .limit(pagination_limit) + ) + ] - return [ - RegisteredFunctionJobDB.model_validate(dict(row)) for row in rows - ], PageMetaInfoLimitOffset( + return results, PageMetaInfoLimitOffset( total=total_count_result, offset=pagination_offset, limit=pagination_limit, - count=len(rows), + count=len(results), ) @@ -494,17 +490,17 @@ async def list_function_job_collections( Returns a list of function job collections and their associated job ids. Filters the collections to include only those that have function jobs with the specified function id if filters.has_function_id is provided. """ - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS, - ], - ) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS, + ], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: filter_condition: sqlalchemy.sql.ColumnElement = sqlalchemy.sql.true() if filters and filters.has_function_id: @@ -554,40 +550,35 @@ async def list_function_job_collections( .select_from(function_job_collections_table) .where(filter_and_access_condition) ) - query = function_job_collections_table.select().where( - filter_and_access_condition - ) - - result = await conn.stream( - query.offset(pagination_offset).limit(pagination_limit) - ) - rows = await result.all() - if rows is None: + if total_count_result == 0: return [], PageMetaInfoLimitOffset( total=0, offset=pagination_offset, limit=pagination_limit, count=0 ) + query = function_job_collections_table.select().where( + filter_and_access_condition + ) + collections = [] - for row in rows: - collection = RegisteredFunctionJobCollectionDB.model_validate(dict(row)) - job_result = await conn.stream( - function_job_collections_to_function_jobs_table.select().where( - function_job_collections_to_function_jobs_table.c.function_job_collection_uuid - == row["uuid"] + async for row in await conn.stream( + query.offset(pagination_offset).limit(pagination_limit) + ): + collection = RegisteredFunctionJobCollectionDB.model_validate(row) + job_ids = [ + job_row.function_job_uuid + async for job_row in await conn.stream( + function_job_collections_to_function_jobs_table.select().where( + function_job_collections_to_function_jobs_table.c.function_job_collection_uuid + == row.uuid + ) ) - ) - job_rows = await job_result.all() - job_ids = ( - [job_row["function_job_uuid"] for job_row in job_rows] - if job_rows - else [] - ) + ] collections.append((collection, job_ids)) return collections, PageMetaInfoLimitOffset( total=total_count_result, offset=pagination_offset, limit=pagination_limit, - count=len(rows), + count=len(collections), ) @@ -599,40 +590,39 @@ async def delete_function( product_name: ProductName, function_id: FunctionID, ) -> None: + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTIONS, + FunctionsApiAccessRights.WRITE_FUNCTIONS, + ], + ) - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTIONS, - FunctionsApiAccessRights.WRITE_FUNCTIONS, - ], - ) - - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_id, - object_type="function", - permissions=["write"], - ) + await check_user_permissions( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + object_id=function_id, + object_type="function", + permissions=["write"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: # Check if the function exists - result = await conn.stream( + result = await transaction.execute( functions_table.select().where(functions_table.c.uuid == function_id) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionIDNotFoundError(function_id=function_id) # Proceed with deletion - await conn.execute( + await transaction.execute( functions_table.delete().where(functions_table.c.uuid == function_id) ) @@ -646,40 +636,40 @@ async def update_function_title( function_id: FunctionID, title: str, ) -> RegisteredFunctionDB: + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTIONS, + FunctionsApiAccessRights.WRITE_FUNCTIONS, + ], + ) - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTIONS, - FunctionsApiAccessRights.WRITE_FUNCTIONS, - ], - ) - - await check_user_permissions( - app, - user_id=user_id, - product_name=product_name, - object_id=function_id, - object_type="function", - permissions=["write"], - ) + await check_user_permissions( + app, + transaction, + user_id=user_id, + product_name=product_name, + object_id=function_id, + object_type="function", + permissions=["write"], + ) - async with transaction_context(get_asyncpg_engine(app)) as conn: - result = await conn.stream( + result = await transaction.execute( functions_table.update() .where(functions_table.c.uuid == function_id) .values(title=title) .returning(*_FUNCTIONS_TABLE_COLS) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionIDNotFoundError(function_id=function_id) - return RegisteredFunctionDB.model_validate(dict(row)) + return RegisteredFunctionDB.model_validate(row) async def update_function_description( @@ -691,38 +681,39 @@ async def update_function_description( function_id: FunctionID, description: str, ) -> RegisteredFunctionDB: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTIONS, - FunctionsApiAccessRights.WRITE_FUNCTIONS, - ], - ) - await check_user_permissions( - app, - user_id=user_id, - product_name=product_name, - object_id=function_id, - object_type="function", - permissions=["write"], - ) + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTIONS, + FunctionsApiAccessRights.WRITE_FUNCTIONS, + ], + ) + await check_user_permissions( + app, + transaction, + user_id=user_id, + product_name=product_name, + object_id=function_id, + object_type="function", + permissions=["write"], + ) - async with transaction_context(get_asyncpg_engine(app)) as conn: - result = await conn.stream( + result = await transaction.execute( functions_table.update() .where(functions_table.c.uuid == function_id) .values(description=description) .returning(*_FUNCTIONS_TABLE_COLS) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionIDNotFoundError(function_id=function_id) - return RegisteredFunctionDB.model_validate(dict(row)) + return RegisteredFunctionDB.model_validate(row) async def get_function_job( @@ -733,35 +724,35 @@ async def get_function_job( product_name: ProductName, function_job_id: FunctionID, ) -> RegisteredFunctionJobDB: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], - ) - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_job_id, - object_type="function_job", - permissions=["read"], - ) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], + ) + await check_user_permissions( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + object_id=function_job_id, + object_type="function_job", + permissions=["read"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + result = await conn.execute( function_jobs_table.select().where( function_jobs_table.c.uuid == function_job_id ) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionJobIDNotFoundError(function_job_id=function_job_id) - return RegisteredFunctionJobDB.model_validate(dict(row)) + return RegisteredFunctionJobDB.model_validate(row) async def delete_function_job( @@ -772,39 +763,39 @@ async def delete_function_job( product_name: ProductName, function_job_id: FunctionID, ) -> None: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTION_JOBS, - FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, - ], - ) - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_job_id, - object_type="function_job", - permissions=["write"], - ) + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTION_JOBS, + FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, + ], + ) + await check_user_permissions( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + object_id=function_job_id, + object_type="function_job", + permissions=["write"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: # Check if the function job exists - result = await conn.stream( + result = await transaction.execute( function_jobs_table.select().where( function_jobs_table.c.uuid == function_job_id ) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionJobIDNotFoundError(function_job_id=function_job_id) # Proceed with deletion - await conn.execute( + await transaction.execute( function_jobs_table.delete().where( function_jobs_table.c.uuid == function_job_id ) @@ -820,48 +811,42 @@ async def find_cached_function_jobs( product_name: ProductName, inputs: FunctionInputs, ) -> list[RegisteredFunctionJobDB] | None: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], - ) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOBS], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + jobs: list[RegisteredFunctionJobDB] = [] + async for row in await conn.stream( function_jobs_table.select().where( function_jobs_table.c.function_uuid == function_id, cast(function_jobs_table.c.inputs, Text) == json.dumps(inputs), - ), - ) - rows = await result.all() - - if rows is None or len(rows) == 0: - return None - - jobs = [] - for row in rows: - job = RegisteredFunctionJobDB.model_validate(dict(row)) - try: - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=job.uuid, - object_type="function_job", - permissions=["read"], ) - except FunctionJobReadAccessDeniedError: - continue + ): + job = RegisteredFunctionJobDB.model_validate(row) + try: + await check_user_permissions( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + object_id=job.uuid, + object_type="function_job", + permissions=["read"], + ) + except FunctionJobReadAccessDeniedError: + continue - jobs.append(job) + jobs.append(job) - if len(jobs) > 0: - return jobs + if len(jobs) > 0: + return jobs - return None + return None async def get_function_job_collection( @@ -872,30 +857,30 @@ async def get_function_job_collection( product_name: ProductName, function_job_collection_id: FunctionID, ) -> tuple[RegisteredFunctionJobCollectionDB, list[FunctionJobID]]: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS], - ) - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_job_collection_id, - object_type="function_job_collection", - permissions=["read"], - ) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_user_api_access_rights( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + api_access_rights=[FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS], + ) + await check_user_permissions( + app, + connection=conn, + user_id=user_id, + product_name=product_name, + object_id=function_job_collection_id, + object_type="function_job_collection", + permissions=["read"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + result = await conn.execute( function_job_collections_table.select().where( function_job_collections_table.c.uuid == function_job_collection_id ) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionJobCollectionIDNotFoundError( @@ -903,19 +888,17 @@ async def get_function_job_collection( ) # Retrieve associated job ids from the join table - job_result = await conn.stream( - function_job_collections_to_function_jobs_table.select().where( - function_job_collections_to_function_jobs_table.c.function_job_collection_uuid - == row["uuid"] + job_ids = [ + job_row.function_job_uuid + async for job_row in await conn.stream( + function_job_collections_to_function_jobs_table.select().where( + function_job_collections_to_function_jobs_table.c.function_job_collection_uuid + == row.uuid + ) ) - ) - job_rows = await job_result.all() - - job_ids = ( - [job_row["function_job_uuid"] for job_row in job_rows] if job_rows else [] - ) + ] - job_collection = RegisteredFunctionJobCollectionDB.model_validate(dict(row)) + job_collection = RegisteredFunctionJobCollectionDB.model_validate(row) return job_collection, job_ids @@ -928,45 +911,45 @@ async def delete_function_job_collection( product_name: ProductName, function_job_collection_id: FunctionID, ) -> None: - await check_user_api_access_rights( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - api_access_rights=[ - FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS, - FunctionsApiAccessRights.WRITE_FUNCTION_JOB_COLLECTIONS, - ], - ) - await check_user_permissions( - app, - connection=connection, - user_id=user_id, - product_name=product_name, - object_id=function_job_collection_id, - object_type="function_job_collection", - permissions=["write"], - ) + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: + await check_user_api_access_rights( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + api_access_rights=[ + FunctionsApiAccessRights.READ_FUNCTION_JOB_COLLECTIONS, + FunctionsApiAccessRights.WRITE_FUNCTION_JOB_COLLECTIONS, + ], + ) + await check_user_permissions( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + object_id=function_job_collection_id, + object_type="function_job_collection", + permissions=["write"], + ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: # Check if the function job collection exists - result = await conn.stream( + result = await transaction.execute( function_job_collections_table.select().where( function_job_collections_table.c.uuid == function_job_collection_id ) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise FunctionJobCollectionIDNotFoundError( function_job_collection_id=function_job_collection_id ) # Proceed with deletion - await conn.execute( + await transaction.execute( function_job_collections_table.delete().where( function_job_collections_table.c.uuid == function_job_collection_id ) ) - await conn.execute( + await transaction.execute( function_job_collections_to_function_jobs_table.delete().where( function_job_collections_to_function_jobs_table.c.function_job_collection_uuid == function_job_collection_id @@ -986,7 +969,6 @@ async def set_group_permissions( write: bool | None = None, execute: bool | None = None, ) -> None: - access_rights_table = None field_name = None if object_type == "function": @@ -1002,20 +984,20 @@ async def set_group_permissions( assert access_rights_table is not None # nosec assert field_name is not None # nosec - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with transaction_context(get_asyncpg_engine(app), connection) as transaction: for object_id in object_ids: # Check if the group already has access rights for the function - result = await conn.stream( + result = await transaction.execute( access_rights_table.select().where( getattr(access_rights_table.c, field_name) == object_id, access_rights_table.c.group_id == group_id, ) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: # Insert new access rights if the group does not have any - await conn.execute( + await transaction.execute( access_rights_table.insert().values( **{field_name: object_id}, group_id=group_id, @@ -1033,7 +1015,7 @@ async def set_group_permissions( "execute": execute if execute is not None else row["execute"], } - await conn.execute( + await transaction.execute( access_rights_table.update() .where( getattr(access_rights_table.c, field_name) == object_id, @@ -1050,37 +1032,47 @@ async def get_user_api_access_rights( user_id: UserID, product_name: ProductName, ) -> FunctionUserApiAccessRights: - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: user_groups = await list_all_user_groups_ids(app, user_id=user_id) - result = await conn.stream( + # Initialize combined permissions with False values + combined_permissions = FunctionUserApiAccessRights( + user_id=user_id, + read_functions=False, + write_functions=False, + execute_functions=False, + read_function_jobs=False, + write_function_jobs=False, + execute_function_jobs=False, + read_function_job_collections=False, + write_function_job_collections=False, + execute_function_job_collections=False, + ) + + # Process each row only once and combine permissions + async for row in await conn.stream( funcapi_api_access_rights_table.select().where( funcapi_api_access_rights_table.c.group_id.in_(user_groups), funcapi_api_access_rights_table.c.product_name == product_name, ) - ) - rows = await result.all() - if not rows: - return FunctionUserApiAccessRights(user_id=user_id) - combined_permissions = { - "read_functions": any(row["read_functions"] for row in rows), - "write_functions": any(row["write_functions"] for row in rows), - "execute_functions": any(row["execute_functions"] for row in rows), - "read_function_jobs": any(row["read_function_jobs"] for row in rows), - "write_function_jobs": any(row["write_function_jobs"] for row in rows), - "execute_function_jobs": any(row["execute_function_jobs"] for row in rows), - "read_function_job_collections": any( - row["read_function_job_collections"] for row in rows - ), - "write_function_job_collections": any( - row["write_function_job_collections"] for row in rows - ), - "execute_function_job_collections": any( - row["execute_function_job_collections"] for row in rows - ), - "user_id": user_id, - } - return FunctionUserApiAccessRights.model_validate(combined_permissions) + ): + combined_permissions.read_functions |= row.read_functions + combined_permissions.write_functions |= row.write_functions + combined_permissions.execute_functions |= row.execute_functions + combined_permissions.read_function_jobs |= row.read_function_jobs + combined_permissions.write_function_jobs |= row.write_function_jobs + combined_permissions.execute_function_jobs |= row.execute_function_jobs + combined_permissions.read_function_job_collections |= ( + row.read_function_job_collections + ) + combined_permissions.write_function_job_collections |= ( + row.write_function_job_collections + ) + combined_permissions.execute_function_job_collections |= ( + row.execute_function_job_collections + ) + + return combined_permissions async def get_user_permissions( @@ -1092,52 +1084,49 @@ async def get_user_permissions( object_id: UUID, object_type: Literal["function", "function_job", "function_job_collection"], ) -> FunctionAccessRightsDB | None: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + await check_exists( + app, + conn, + object_id=object_id, + object_type=object_type, + ) - await check_exists( - app, - object_id=object_id, - object_type=object_type, - ) - - access_rights_table = None - cols = None - if object_type == "function": - access_rights_table = functions_access_rights_table - cols = _FUNCTIONS_ACCESS_RIGHTS_TABLE_COLS - elif object_type == "function_job": - access_rights_table = function_jobs_access_rights_table - cols = _FUNCTION_JOBS_ACCESS_RIGHTS_TABLE_COLS - elif object_type == "function_job_collection": - access_rights_table = function_job_collections_access_rights_table - cols = _FUNCTION_JOB_COLLECTIONS_ACCESS_RIGHTS_TABLE_COLS - assert access_rights_table is not None # nosec + access_rights_table = None + cols = None + if object_type == "function": + access_rights_table = functions_access_rights_table + cols = _FUNCTIONS_ACCESS_RIGHTS_TABLE_COLS + elif object_type == "function_job": + access_rights_table = function_jobs_access_rights_table + cols = _FUNCTION_JOBS_ACCESS_RIGHTS_TABLE_COLS + elif object_type == "function_job_collection": + access_rights_table = function_job_collections_access_rights_table + cols = _FUNCTION_JOB_COLLECTIONS_ACCESS_RIGHTS_TABLE_COLS + assert access_rights_table is not None # nosec - async with transaction_context(get_asyncpg_engine(app), connection) as conn: user_groups = await list_all_user_groups_ids(app, user_id=user_id) - # Combine permissions for all groups the user belongs to - result = await conn.stream( + # Initialize combined permissions with False values + combined_permissions = FunctionAccessRightsDB( + read=False, write=False, execute=False + ) + + # Process each row only once and combine permissions + async for row in await conn.stream( access_rights_table.select() - .with_only_columns(cols) + .with_only_columns(*cols) .where( getattr(access_rights_table.c, f"{object_type}_uuid") == object_id, access_rights_table.c.product_name == product_name, access_rights_table.c.group_id.in_(user_groups), ) - ) - rows = await result.all() - - if not rows: - return None - - # Combine permissions across all rows - combined_permissions = { - "read": any(row["read"] for row in rows), - "write": any(row["write"] for row in rows), - "execute": any(row["execute"] for row in rows), - } + ): + combined_permissions.read |= row.read + combined_permissions.write |= row.write + combined_permissions.execute |= row.execute - return FunctionAccessRightsDB.model_validate(combined_permissions) + return combined_permissions async def check_exists( @@ -1168,11 +1157,11 @@ async def check_exists( function_job_collection_id=object_id ) - async with transaction_context(get_asyncpg_engine(app), connection) as conn: - result = await conn.stream( + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.execute( main_table.select().where(main_table.c.uuid == object_id) ) - row = await result.one_or_none() + row = result.one_or_none() if row is None: raise error diff --git a/services/web/server/tests/unit/with_dbs/04/functions_rpc/conftest.py b/services/web/server/tests/unit/with_dbs/04/functions_rpc/conftest.py index c9484cb224ae..3649ba419a27 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions_rpc/conftest.py +++ b/services/web/server/tests/unit/with_dbs/04/functions_rpc/conftest.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +import sqlalchemy from aiohttp.test_utils import TestClient from models_library.api_schemas_webserver.functions import ( Function, @@ -188,7 +189,7 @@ async def add_user_function_api_access_rights( async with asyncpg_engine.begin() as conn: for group_id in (logged_user["primary_gid"], other_logged_user["primary_gid"]): await conn.execute( - funcapi_api_access_rights_table.delete( # type: ignore[union-attr] + sqlalchemy.delete(funcapi_api_access_rights_table).where( funcapi_api_access_rights_table.c.group_id == group_id ) ) diff --git a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_job_collections_controller_rpc.py b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_job_collections_controller_rpc.py index c5e0f3b4d6a0..0d5245d44c18 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_job_collections_controller_rpc.py +++ b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_job_collections_controller_rpc.py @@ -17,6 +17,7 @@ from models_library.functions_errors import ( FunctionJobCollectionReadAccessDeniedError, FunctionJobCollectionsReadApiAccessDeniedError, + FunctionJobCollectionWriteAccessDeniedError, FunctionJobIDNotFoundError, ) from models_library.products import ProductName @@ -137,7 +138,7 @@ async def test_function_job_collection( ) # Attempt to delete the function job collection by another user - with pytest.raises(FunctionJobCollectionReadAccessDeniedError): + with pytest.raises(FunctionJobCollectionWriteAccessDeniedError): await functions_rpc.delete_function_job_collection( rabbitmq_rpc_client=rpc_client, function_job_collection_id=registered_collection.uid, diff --git a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_jobs_controller_rpc.py b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_jobs_controller_rpc.py index a3d0bf046cec..4c564d395147 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_jobs_controller_rpc.py +++ b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_function_jobs_controller_rpc.py @@ -15,6 +15,7 @@ FunctionJobIDNotFoundError, FunctionJobReadAccessDeniedError, FunctionJobsReadApiAccessDeniedError, + FunctionJobWriteAccessDeniedError, ) from models_library.products import ProductName from pytest_simcore.helpers.webserver_login import UserInfoDict @@ -104,7 +105,7 @@ async def test_register_get_delete_function_job( product_name="this_is_not_osparc", ) - with pytest.raises(FunctionJobReadAccessDeniedError): + with pytest.raises(FunctionJobWriteAccessDeniedError): # Attempt to delete the function job by another user await functions_rpc.delete_function_job( rabbitmq_rpc_client=rpc_client, @@ -301,7 +302,6 @@ async def test_find_cached_function_jobs( mock_function: ProjectFunction, clean_functions: None, ): - # Register the function first registered_function = await functions_rpc.register_function( rabbitmq_rpc_client=rpc_client, diff --git a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rpc.py b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rpc.py index 99c9855caee6..2e60cfac6771 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rpc.py +++ b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rpc.py @@ -19,6 +19,7 @@ FunctionIDNotFoundError, FunctionReadAccessDeniedError, FunctionsReadApiAccessDeniedError, + FunctionWriteAccessDeniedError, ) from models_library.products import ProductName from pytest_simcore.helpers.webserver_login import UserInfoDict @@ -88,7 +89,7 @@ async def test_register_get_delete_function( product_name=osparc_product_name, ) - with pytest.raises(FunctionReadAccessDeniedError): + with pytest.raises(FunctionWriteAccessDeniedError): # Attempt to delete the function by another user await functions_rpc.delete_function( rabbitmq_rpc_client=rpc_client, @@ -374,7 +375,7 @@ async def test_update_function_title( # Update the function's title by other user updated_title = "Updated Function Title by Other User" registered_function.title = updated_title - with pytest.raises(FunctionReadAccessDeniedError): + with pytest.raises(FunctionWriteAccessDeniedError): updated_function = await functions_rpc.update_function_title( rabbitmq_rpc_client=rpc_client, function_id=registered_function.uid,