|
2 | 2 |
|
3 | 3 | import sqlalchemy as sa |
4 | 4 | from models_library.projects import ProjectAtDB, ProjectID |
5 | | -from models_library.projects_nodes import Node |
6 | 5 | from models_library.projects_nodes_io import NodeID |
7 | | -from simcore_postgres_database.utils_projects_nodes import ProjectNode, ProjectNodesRepo |
| 6 | +from simcore_postgres_database.utils_projects_nodes import ( |
| 7 | + ProjectNodesRepo, |
| 8 | + make_workbench_subquery, |
| 9 | +) |
8 | 10 | from simcore_postgres_database.utils_repos import pass_or_acquire_connection |
9 | 11 |
|
10 | 12 | from ....core.errors import ProjectNotFoundError |
|
14 | 16 | logger = logging.getLogger(__name__) |
15 | 17 |
|
16 | 18 |
|
17 | | -def _project_node_to_node(project_node: ProjectNode) -> Node: |
18 | | - """Converts a ProjectNode from the database to a Node model for the API. |
19 | | -
|
20 | | - Handles field mapping and excludes database-specific fields that are not |
21 | | - part of the Node model. |
22 | | - """ |
23 | | - node_data = project_node.model_dump_as_node() |
24 | | - return Node.model_validate(node_data) |
25 | | - |
26 | | - |
27 | 19 | class ProjectsRepository(BaseRepository): |
28 | 20 | async def get_project(self, project_id: ProjectID) -> ProjectAtDB: |
| 21 | + workbench_subquery = make_workbench_subquery() |
| 22 | + |
29 | 23 | async with self.db_engine.connect() as conn: |
30 | | - row = ( |
31 | | - await conn.execute( |
32 | | - sa.select(projects).where(projects.c.uuid == str(project_id)) |
| 24 | + query = ( |
| 25 | + sa.select( |
| 26 | + projects, |
| 27 | + sa.func.coalesce( |
| 28 | + workbench_subquery.c.workbench, sa.text("'{}'::json") |
| 29 | + ).label("workbench"), |
| 30 | + ) |
| 31 | + .select_from( |
| 32 | + projects.outerjoin( |
| 33 | + workbench_subquery, |
| 34 | + projects.c.uuid == workbench_subquery.c.project_uuid, |
| 35 | + ) |
33 | 36 | ) |
34 | | - ).one_or_none() |
| 37 | + .where(projects.c.uuid == str(project_id)) |
| 38 | + ) |
| 39 | + result = await conn.execute(query) |
| 40 | + row = result.one_or_none() |
35 | 41 | if not row: |
36 | 42 | raise ProjectNotFoundError(project_id=project_id) |
37 | | - |
38 | | - repo = ProjectNodesRepo(project_uuid=project_id) |
39 | | - nodes = await repo.list(conn) |
40 | | - |
41 | | - project_workbench = { |
42 | | - f"{node.node_id}": _project_node_to_node(node) for node in nodes |
43 | | - } |
44 | | - data = {**row._asdict(), "workbench": project_workbench} |
45 | | - return ProjectAtDB.model_validate(data) |
| 43 | + return ProjectAtDB.model_validate(row) |
46 | 44 |
|
47 | 45 | async def is_node_present_in_workbench( |
48 | 46 | self, project_id: ProjectID, node_uuid: NodeID |
49 | 47 | ) -> bool: |
50 | 48 | async with pass_or_acquire_connection(self.db_engine) as conn: |
51 | | - result = await conn.execute( |
52 | | - sa.select(projects_nodes.c.project_node_id).where( |
| 49 | + stmt = ( |
| 50 | + sa.select(sa.literal(1)) |
| 51 | + .where( |
53 | 52 | projects_nodes.c.project_uuid == str(project_id), |
54 | 53 | projects_nodes.c.node_id == str(node_uuid), |
55 | 54 | ) |
| 55 | + .limit(1) |
56 | 56 | ) |
57 | | - project_node = result.one_or_none() |
58 | | - return project_node is not None |
| 57 | + |
| 58 | + result = await conn.execute(stmt) |
| 59 | + return result.scalar_one_or_none() is not None |
59 | 60 |
|
60 | 61 | async def get_project_id_from_node(self, node_id: NodeID) -> ProjectID: |
61 | 62 | async with self.db_engine.connect() as conn: |
|
0 commit comments