|
 from typing import List, Optional
 
 import sqlalchemy as sa
+from aiopg.sa.result import ResultProxy
 from models_library.projects import ProjectID
 from models_library.projects_nodes import Node
 from models_library.projects_nodes_io import NodeID
@@ -57,6 +58,69 @@ def _get_fake_service_details(
     return None
 
 
+async def _generate_tasks_list_from_project(
+    project: ProjectAtDB,
+    director_client: DirectorV0Client,
+    published_nodes: List[NodeID],
+) -> List[CompTaskAtDB]:
+
+    list_comp_tasks = []
+
+    for internal_id, node_id in enumerate(project.workbench, 1):
+        node: Node = project.workbench[node_id]
+
+        service_key_version = ServiceKeyVersion(
+            key=node.key,
+            version=node.version,
+        )
+        node_class = to_node_class(service_key_version.key)
+        node_details: Optional[ServiceDockerData] = None
+        node_extras: Optional[ServiceExtras] = None
+        if node_class == NodeClass.FRONTEND:
+            node_details = _get_fake_service_details(service_key_version)
+        else:
+            node_details = await director_client.get_service_details(
+                service_key_version
+            )
+            node_extras = await director_client.get_service_extras(
+                service_key_version
+            )
+        if not node_details:
+            continue
+
+        requires_mpi = False
+        requires_gpu = False
+        if node_extras:
+            requires_gpu = NodeRequirement.GPU in node_extras.node_requirements
+            requires_mpi = NodeRequirement.MPI in node_extras.node_requirements
+        image = Image(
+            name=service_key_version.key,
+            tag=service_key_version.version,
+            requires_gpu=requires_gpu,
+            requires_mpi=requires_mpi,
+        )
+
+        task_db = CompTaskAtDB(
+            project_id=project.uuid,
+            node_id=node_id,
+            schema=NodeSchema(inputs=node_details.inputs, outputs=node_details.outputs),
+            inputs=node.inputs,
+            outputs=node.outputs,
+            image=image,
+            submit=datetime.utcnow(),
+            state=(
+                RunningState.PUBLISHED
+                if node_id in published_nodes and node_class == NodeClass.COMPUTATIONAL
+                else RunningState.NOT_STARTED
+            ),
+            internal_id=internal_id,
+            node_class=node_class,
+        )
+
+        list_comp_tasks.append(task_db)
+    return list_comp_tasks
+
+
 class CompTasksRepository(BaseRepository):
     @log_decorator(logger=logger)
     async def get_comp_tasks(
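A side note on the loop header in the new helper: `project.workbench` maps node ids to nodes, so iterating it yields the node ids, and `enumerate(..., 1)` attaches a 1-based `internal_id` to each. A minimal, runnable sketch (the ids and service keys below are made-up placeholders, not real project data):

    # Sketch only: how the loop pairs each workbench node id with a 1-based internal id.
    workbench = {
        "node-a": "simcore/services/comp/sleeper",          # hypothetical computational node
        "node-b": "simcore/services/frontend/file-picker",  # hypothetical frontend node
    }
    for internal_id, node_id in enumerate(workbench, 1):
        print(internal_id, node_id)
    # -> 1 node-a
    # -> 2 node-b

A task then starts in RunningState.PUBLISHED only when its node id is listed in `published_nodes` and the node is computational; every other node (frontend nodes included) starts as RunningState.NOT_STARTED.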
@@ -84,71 +148,45 @@ async def _sequentially_upsert_tasks_from_project( |
         published_nodes: List[NodeID],
         str_project_uuid: str,
     ) -> None:
-        # start by removing the old tasks if they exist
-        await self.connection.execute(
-            sa.delete(comp_tasks).where(comp_tasks.c.project_id == str(project.uuid))
-        )
-        # create the tasks
-
-        for internal_id, node_id in enumerate(project.workbench, 1):
-            node: Node = project.workbench[node_id]
 
-            service_key_version = ServiceKeyVersion(
-                key=node.key,
-                version=node.version,
+        # NOTE: really do an upsert here because of issue https://github.com/ITISFoundation/osparc-simcore/issues/2125
+        list_of_comp_tasks_in_project = await _generate_tasks_list_from_project(
+            project, director_client, published_nodes
+        )
+        # get the node ids of the tasks currently stored for this project
+        result: ResultProxy = await self.connection.execute(
+            sa.select([comp_tasks.c.node_id]).where(
+                comp_tasks.c.project_id == str(project.uuid)
             )
-            node_class = to_node_class(service_key_version.key)
-            node_details: ServiceDockerData = None
-            node_extras: ServiceExtras = None
-            if node_class == NodeClass.FRONTEND:
-                node_details = _get_fake_service_details(service_key_version)
-            else:
-                node_details = await director_client.get_service_details(
-                    service_key_version
-                )
-                node_extras: ServiceExtras = await director_client.get_service_extras(
-                    service_key_version
+        )
+        # remove the tasks whose nodes were removed from the project workbench
+        node_ids_to_delete = [
+            t.node_id
+            for t in await result.fetchall()
+            if t.node_id not in project.workbench
+        ]
+        for node_id in node_ids_to_delete:
+            await self.connection.execute(
+                sa.delete(comp_tasks).where(
+                    (comp_tasks.c.project_id == str(project.uuid))
+                    & (comp_tasks.c.node_id == node_id)
                 )
-            if not node_details:
-                continue
-
-            requires_mpi = False
-            requires_gpu = False
-            if node_extras:
-                requires_gpu = NodeRequirement.GPU in node_extras.node_requirements
-                requires_mpi = NodeRequirement.MPI in node_extras.node_requirements
-            image = Image(
-                name=service_key_version.key,
-                tag=service_key_version.version,
-                requires_gpu=requires_gpu,
-                requires_mpi=requires_mpi,
             )
 
-            task_db = CompTaskAtDB(
-                project_id=project.uuid,
-                node_id=node_id,
-                schema=NodeSchema(
-                    inputs=node_details.inputs, outputs=node_details.outputs
-                ),
-                inputs=node.inputs,
-                outputs=node.outputs,
-                image=image,
-                submit=datetime.utcnow(),
-                state=(
-                    RunningState.PUBLISHED
-                    if node_id in published_nodes
-                    and node_class == NodeClass.COMPUTATIONAL
-                    else RunningState.NOT_STARTED
-                ),
-                internal_id=internal_id,
-                node_class=node_class,
-            )
+        # insert or update the remaining tasks
+        # NOTE: the comp_tasks table only triggers a notification to the webserver if an UPDATE on comp_tasks.outputs or comp_tasks.state is done
+        for comp_task_db in list_of_comp_tasks_in_project:
 
-            await self.connection.execute(
-                insert(comp_tasks).values(
-                    **task_db.dict(by_alias=True, exclude_unset=True)
-                )
+            insert_stmt = insert(comp_tasks).values(
+                **comp_task_db.dict(by_alias=True, exclude_unset=True)
+            )
+            on_update_stmt = insert_stmt.on_conflict_do_update(
+                index_elements=[comp_tasks.c.project_id, comp_tasks.c.node_id],
+                set_=comp_task_db.dict(
+                    by_alias=True, exclude_unset=True, exclude={"outputs", "state"}
+                ),
             )
+            await self.connection.execute(on_update_stmt)
 
|
     @log_decorator(logger=logger)
     async def upsert_tasks_from_project(
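The key change in this second hunk is replacing the old delete-and-reinsert with a real PostgreSQL upsert, while excluding `outputs` and `state` from the update set so the notification trigger mentioned in the NOTE fires only for genuine changes. Below is a small, self-contained sketch of that pattern; the table definition and values are simplified stand-ins, not the real comp_tasks schema:

    # Standalone sketch of the on_conflict_do_update pattern (simplified, hypothetical schema).
    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql

    metadata = sa.MetaData()
    comp_tasks = sa.Table(
        "comp_tasks",
        metadata,
        sa.Column("project_id", sa.String, primary_key=True),
        sa.Column("node_id", sa.String, primary_key=True),
        sa.Column("image", sa.JSON),
        sa.Column("outputs", sa.JSON),
        sa.Column("state", sa.String),
    )

    # Hypothetical task values for illustration only.
    task_values = {
        "project_id": "some-project-uuid",
        "node_id": "some-node-uuid",
        "image": {"name": "simcore/services/comp/sleeper", "tag": "1.0.0"},
        "outputs": {},
        "state": "PUBLISHED",
    }

    insert_stmt = postgresql.insert(comp_tasks).values(**task_values)
    # On a (project_id, node_id) conflict, update everything except "outputs"
    # and "state", so a DB trigger watching those columns is not fired.
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=[comp_tasks.c.project_id, comp_tasks.c.node_id],
        set_={k: v for k, v in task_values.items() if k not in ("outputs", "state")},
    )
    print(upsert_stmt.compile(dialect=postgresql.dialect()))

The diff achieves the same column exclusion with pydantic's `dict(..., exclude={"outputs", "state"})`, so tasks that did not really change keep their stored outputs and state instead of being wiped and recreated.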
|