Skip to content

Commit 3f92f41

Browse files
authored
really upsert in comp_task (#2126)
1 parent 4c0db62 commit 3f92f41

File tree

1 file changed

+96
-58
lines changed
  • services/director-v2/src/simcore_service_director_v2/modules/db/repositories

1 file changed

+96
-58
lines changed

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks.py

Lines changed: 96 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import List, Optional
44

55
import sqlalchemy as sa
6+
from aiopg.sa.result import RowProxy
67
from models_library.projects import ProjectID
78
from models_library.projects_nodes import Node
89
from models_library.projects_nodes_io import NodeID
@@ -57,6 +58,69 @@ def _get_fake_service_details(
5758
return None
5859

5960

61+
async def _generate_tasks_list_from_project(
62+
project: ProjectAtDB,
63+
director_client: DirectorV0Client,
64+
published_nodes: List[NodeID],
65+
) -> List[CompTaskAtDB]:
66+
67+
list_comp_tasks = []
68+
69+
for internal_id, node_id in enumerate(project.workbench, 1):
70+
node: Node = project.workbench[node_id]
71+
72+
service_key_version = ServiceKeyVersion(
73+
key=node.key,
74+
version=node.version,
75+
)
76+
node_class = to_node_class(service_key_version.key)
77+
node_details: ServiceDockerData = None
78+
node_extras: ServiceExtras = None
79+
if node_class == NodeClass.FRONTEND:
80+
node_details = _get_fake_service_details(service_key_version)
81+
else:
82+
node_details = await director_client.get_service_details(
83+
service_key_version
84+
)
85+
node_extras: ServiceExtras = await director_client.get_service_extras(
86+
service_key_version
87+
)
88+
if not node_details:
89+
continue
90+
91+
requires_mpi = False
92+
requires_gpu = False
93+
if node_extras:
94+
requires_gpu = NodeRequirement.GPU in node_extras.node_requirements
95+
requires_mpi = NodeRequirement.MPI in node_extras.node_requirements
96+
image = Image(
97+
name=service_key_version.key,
98+
tag=service_key_version.version,
99+
requires_gpu=requires_gpu,
100+
requires_mpi=requires_mpi,
101+
)
102+
103+
task_db = CompTaskAtDB(
104+
project_id=project.uuid,
105+
node_id=node_id,
106+
schema=NodeSchema(inputs=node_details.inputs, outputs=node_details.outputs),
107+
inputs=node.inputs,
108+
outputs=node.outputs,
109+
image=image,
110+
submit=datetime.utcnow(),
111+
state=(
112+
RunningState.PUBLISHED
113+
if node_id in published_nodes and node_class == NodeClass.COMPUTATIONAL
114+
else RunningState.NOT_STARTED
115+
),
116+
internal_id=internal_id,
117+
node_class=node_class,
118+
)
119+
120+
list_comp_tasks.append(task_db)
121+
return list_comp_tasks
122+
123+
60124
class CompTasksRepository(BaseRepository):
61125
@log_decorator(logger=logger)
62126
async def get_comp_tasks(
@@ -84,71 +148,45 @@ async def _sequentially_upsert_tasks_from_project(
84148
published_nodes: List[NodeID],
85149
str_project_uuid: str,
86150
) -> None:
87-
# start by removing the old tasks if they exist
88-
await self.connection.execute(
89-
sa.delete(comp_tasks).where(comp_tasks.c.project_id == str(project.uuid))
90-
)
91-
# create the tasks
92-
93-
for internal_id, node_id in enumerate(project.workbench, 1):
94-
node: Node = project.workbench[node_id]
95151

96-
service_key_version = ServiceKeyVersion(
97-
key=node.key,
98-
version=node.version,
152+
# NOTE: really do an upsert here because of issue https://github.com/ITISFoundation/osparc-simcore/issues/2125
153+
list_of_comp_tasks_in_project = await _generate_tasks_list_from_project(
154+
project, director_client, published_nodes
155+
)
156+
# get current tasks
157+
result: RowProxy = await self.connection.execute(
158+
sa.select([comp_tasks.c.node_id]).where(
159+
comp_tasks.c.project_id == str(project.uuid)
99160
)
100-
node_class = to_node_class(service_key_version.key)
101-
node_details: ServiceDockerData = None
102-
node_extras: ServiceExtras = None
103-
if node_class == NodeClass.FRONTEND:
104-
node_details = _get_fake_service_details(service_key_version)
105-
else:
106-
node_details = await director_client.get_service_details(
107-
service_key_version
108-
)
109-
node_extras: ServiceExtras = await director_client.get_service_extras(
110-
service_key_version
161+
)
162+
# remove the tasks that were removed from project workbench
163+
node_ids_to_delete = [
164+
t.node_id
165+
for t in await result.fetchall()
166+
if t.node_id not in project.workbench
167+
]
168+
for node_id in node_ids_to_delete:
169+
await self.connection.execute(
170+
sa.delete(comp_tasks).where(
171+
(comp_tasks.c.project_id == str(project.uuid))
172+
& (comp_tasks.c.node_id == node_id)
111173
)
112-
if not node_details:
113-
continue
114-
115-
requires_mpi = False
116-
requires_gpu = False
117-
if node_extras:
118-
requires_gpu = NodeRequirement.GPU in node_extras.node_requirements
119-
requires_mpi = NodeRequirement.MPI in node_extras.node_requirements
120-
image = Image(
121-
name=service_key_version.key,
122-
tag=service_key_version.version,
123-
requires_gpu=requires_gpu,
124-
requires_mpi=requires_mpi,
125174
)
126175

127-
task_db = CompTaskAtDB(
128-
project_id=project.uuid,
129-
node_id=node_id,
130-
schema=NodeSchema(
131-
inputs=node_details.inputs, outputs=node_details.outputs
132-
),
133-
inputs=node.inputs,
134-
outputs=node.outputs,
135-
image=image,
136-
submit=datetime.utcnow(),
137-
state=(
138-
RunningState.PUBLISHED
139-
if node_id in published_nodes
140-
and node_class == NodeClass.COMPUTATIONAL
141-
else RunningState.NOT_STARTED
142-
),
143-
internal_id=internal_id,
144-
node_class=node_class,
145-
)
176+
# insert or update the remaining tasks
177+
# NOTE: comp_tasks DB only trigger a notification to the webserver if an UPDATE on comp_tasks.outputs or comp_tasks.state is done
178+
for comp_task_db in list_of_comp_tasks_in_project:
146179

147-
await self.connection.execute(
148-
insert(comp_tasks).values(
149-
**task_db.dict(by_alias=True, exclude_unset=True)
150-
)
180+
insert_stmt = insert(comp_tasks).values(
181+
**comp_task_db.dict(by_alias=True, exclude_unset=True)
182+
)
183+
on_update_stmt = insert_stmt.on_conflict_do_update(
184+
index_elements=[comp_tasks.c.project_id, comp_tasks.c.node_id],
185+
set_=comp_task_db.dict(
186+
by_alias=True, exclude_unset=True, exclude={"outputs", "state"}
187+
),
151188
)
189+
await self.connection.execute(on_update_stmt)
152190

153191
@log_decorator(logger=logger)
154192
async def upsert_tasks_from_project(

0 commit comments

Comments
 (0)