Commit 4fbb625

fix tests

1 parent 48fac73 commit 4fbb625

7 files changed: +87 -13 lines changed

packages/models-library/src/models_library/projects_nodes.py

Lines changed: 1 addition & 1 deletion

@@ -165,7 +165,7 @@ class Node(BaseModel):
     ] = DEFAULT_FACTORY
 
     inputs_required: Annotated[
-        list[InputID],
+        list[InputID] | None,
         Field(
             default_factory=list,
             description="Defines inputs that are required in order to run the service",

services/director-v2/src/simcore_service_director_v2/api/routes/computations.py

Lines changed: 16 additions & 3 deletions

@@ -27,7 +27,7 @@
     ComputationGet,
     ComputationStop,
 )
-from models_library.projects import ProjectAtDB, ProjectID
+from models_library.projects import NodesDict, ProjectAtDB, ProjectID
 from models_library.projects_nodes_io import NodeID
 from models_library.projects_state import RunningState
 from models_library.services import ServiceKeyVersion
@@ -38,6 +38,9 @@
 from servicelib.logging_utils import log_decorator
 from servicelib.rabbitmq import RabbitMQRPCClient
 from simcore_postgres_database.utils_projects_metadata import DBProjectNotFoundError
+from simcore_service_director_v2.modules.db.repositories.projects_nodes import (
+    ProjectsNodesRepository,
+)
 from starlette import status
 from starlette.requests import Request
 from tenacity import retry
@@ -196,6 +199,7 @@ async def _try_start_pipeline(
     complete_dag: nx.DiGraph,
     minimal_dag: nx.DiGraph,
     project: ProjectAtDB,
+    workbench: NodesDict,
     users_repo: UsersRepository,
     projects_metadata_repo: ProjectsMetadataRepository,
 ) -> None:
@@ -226,7 +230,7 @@ async def _try_start_pipeline(
         run_metadata=RunMetadataDict(
             node_id_names_map={
                 NodeID(node_idstr): node_data.label
-                for node_idstr, node_data in project.workbench.items()
+                for node_idstr, node_data in workbench.items()
             },
             product_name=computation.product_name,
             project_name=project.name,
@@ -273,6 +277,9 @@ async def create_computation(  # noqa: PLR0913 # pylint: disable=too-many-positi
     project_repo: Annotated[
         ProjectsRepository, Depends(get_repository(ProjectsRepository))
     ],
+    project_nodes_repo: Annotated[
+        ProjectsNodesRepository, Depends(get_repository(ProjectsNodesRepository))
+    ],
     comp_pipelines_repo: Annotated[
         CompPipelinesRepository, Depends(get_repository(CompPipelinesRepository))
     ],
@@ -302,8 +309,12 @@ async def create_computation(  # noqa: PLR0913 # pylint: disable=too-many-positi
     # check if current state allow to modify the computation
     await _check_pipeline_not_running_or_raise_409(comp_tasks_repo, computation)
 
+    workbench: NodesDict = await project_nodes_repo.get_nodes(
+        computation.project_id
+    )
+
     # create the complete DAG graph
-    complete_dag = create_complete_dag(project.workbench)
+    complete_dag = create_complete_dag(workbench)
     # find the minimal viable graph to be run
     minimal_computational_dag: nx.DiGraph = (
         await create_minimal_computational_graph_based_on_selection(
@@ -330,6 +341,7 @@ async def create_computation(  # noqa: PLR0913 # pylint: disable=too-many-positi
     ]
     comp_tasks = await comp_tasks_repo.upsert_tasks_from_project(
         project=project,
+        workbench=workbench,
         catalog_client=catalog_client,
         published_nodes=min_computation_nodes if computation.start_pipeline else [],
         user_id=computation.user_id,
@@ -347,6 +359,7 @@ async def create_computation(  # noqa: PLR0913 # pylint: disable=too-many-positi
         complete_dag=complete_dag,
         minimal_dag=minimal_computational_dag,
         project=project,
+        workbench=workbench,
         users_repo=users_repo,
         projects_metadata_repo=projects_metadata_repo,
     )
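
Read together, the hunks above thread an explicitly loaded workbench through create_computation. A condensed sketch assembled from the diff (not new code; names and signatures are exactly those shown above):

# Condensed from the hunks above: the workbench now comes from the
# projects_nodes table instead of project.workbench.
workbench: NodesDict = await project_nodes_repo.get_nodes(computation.project_id)

complete_dag = create_complete_dag(workbench)  # was create_complete_dag(project.workbench)

comp_tasks = await comp_tasks_repo.upsert_tasks_from_project(
    project=project,
    workbench=workbench,  # new explicit parameter
    catalog_client=catalog_client,
    published_nodes=min_computation_nodes if computation.start_pipeline else [],
    user_id=computation.user_id,
)

The same workbench is also forwarded to _try_start_pipeline, which uses it to build the node_id_names_map for the run metadata.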

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py

Lines changed: 4 additions & 1 deletion

@@ -26,6 +26,7 @@
 from ....catalog import CatalogClient
 from ...tables import NodeClass, StateType, comp_tasks
 from .._base import BaseRepository
+from ..projects_nodes import NodesDict
 from . import _utils
 
 _logger = logging.getLogger(__name__)
@@ -89,6 +90,7 @@ async def upsert_tasks_from_project(
         self,
         *,
         project: ProjectAtDB,
+        workbench: NodesDict,
         catalog_client: CatalogClient,
         published_nodes: list[NodeID],
         user_id: UserID,
@@ -103,6 +105,7 @@ async def upsert_tasks_from_project(
             CompTaskAtDB
         ] = await _utils.generate_tasks_list_from_project(
             project=project,
+            workbench=workbench,
             catalog_client=catalog_client,
             published_nodes=published_nodes,
             user_id=user_id,
@@ -121,7 +124,7 @@ async def upsert_tasks_from_project(
         # remove the tasks that were removed from project workbench
         if all_nodes := await result.fetchall():
             node_ids_to_delete = [
-                t.node_id for t in all_nodes if t.node_id not in project.workbench
+                t.node_id for t in all_nodes if t.node_id not in workbench
             ]
             for node_id in node_ids_to_delete:
                 await conn.execute(
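
The last hunk keeps the delete-orphaned-tasks logic intact but tests membership against the explicit workbench mapping. A standalone illustration with hypothetical node IDs (not the repository code):

# Hypothetical IDs; shows the reconciliation rule only.
workbench = {"node-a": object(), "node-b": object()}
stored_task_node_ids = ["node-a", "node-b", "node-c"]  # as fetched from comp_tasks

node_ids_to_delete = [
    node_id for node_id in stored_task_node_ids if node_id not in workbench
]
assert node_ids_to_delete == ["node-c"]  # node-c is gone from the workbench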

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py

Lines changed: 5 additions & 3 deletions

@@ -60,6 +60,7 @@
 from ....catalog import CatalogClient
 from ....comp_scheduler._utils import COMPLETED_STATES
 from ...tables import NodeClass
+from ..projects_nodes import NodesDict
 
 _logger = logging.getLogger(__name__)
 
@@ -331,6 +332,7 @@ def _by_type_name(ec2: EC2InstanceTypeGet) -> bool:
 async def generate_tasks_list_from_project(
     *,
     project: ProjectAtDB,
+    workbench: NodesDict,
     catalog_client: CatalogClient,
     published_nodes: list[NodeID],
     user_id: UserID,
@@ -346,7 +348,7 @@ async def generate_tasks_list_from_project(
         ServiceKeyVersion(
             key=node.key, version=node.version
         )  # the service key version is frozen
-        for node in project.workbench.values()
+        for node in workbench.values()
     }
 
     key_version_to_node_infos = {
@@ -359,8 +361,8 @@ async def generate_tasks_list_from_project(
         for key_version in unique_service_key_versions
     }
 
-    for internal_id, node_id in enumerate(project.workbench, 1):
-        node: Node = project.workbench[node_id]
+    for internal_id, node_id in enumerate(workbench, start=1):
+        node: Node = workbench[node_id]
         node_key_version = ServiceKeyVersion(key=node.key, version=node.version)
         node_details, node_extras, node_labels = key_version_to_node_infos.get(
             node_key_version,
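
The loop rewrite is behaviour-preserving: iterating a dict yields its keys, and enumerate(workbench, start=1) equals enumerate(workbench, 1) with the start made explicit. A tiny runnable illustration with made-up workbench contents:

# Made-up contents; real values are Node models, not strings.
workbench = {"node-a": "sleeper", "node-b": "solver"}

for internal_id, node_id in enumerate(workbench, start=1):
    print(internal_id, node_id, workbench[node_id])
# 1 node-a sleeper
# 2 node-b solver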
services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects_nodes.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+import logging
+
+import sqlalchemy as sa
+from models_library.projects import NodesDict, ProjectID
+from models_library.projects_nodes import Node
+from simcore_postgres_database.utils_projects_nodes import ProjectNode
+
+from ..tables import projects_nodes
+from ._base import BaseRepository
+
+logger = logging.getLogger(__name__)
+
+
+class ProjectsNodesRepository(BaseRepository):
+    async def get_nodes(self, project_uuid: ProjectID) -> NodesDict:
+        nodes_dict = {}
+        async with self.db_engine.acquire() as conn:
+            rows = await (
+                await conn.execute(
+                    sa.select(projects_nodes).where(
+                        projects_nodes.c.project_uuid == f"{project_uuid}"
+                    )
+                )
+            ).fetchall()
+
+            for row in rows:
+                nodes_dict[f"{row.node_id}"] = Node.model_validate(
+                    ProjectNode.model_validate(row, from_attributes=True).model_dump(
+                        exclude={
+                            "node_id",
+                            "required_resources",
+                            "created",
+                            "modified",
+                        },
+                        exclude_none=True,
+                        exclude_unset=True,
+                    )
+                )
+
+        return nodes_dict
+
+    # async def is_node_present_in_workbench(
+    #     self, project_id: ProjectID, node_uuid: NodeID
+    # ) -> bool:
+    #     try:
+    #         project = await self.get_project(project_id)
+    #         return f"{node_uuid}" in project.workbench
+    #     except ProjectNotFoundError:
+    #         return False
+
+    # async def get_project_id_from_node(self, node_id: NodeID) -> ProjectID:
+    #     async with self.db_engine.acquire() as conn:
+    #         return await ProjectNodesRepo.get_project_id_from_node_id(
+    #             conn, node_id=node_id
+    #         )
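
get_nodes round-trips each row through two validations: the row becomes a ProjectNode (from_attributes=True), is dumped with the DB-only columns (node_id, required_resources, created, modified) excluded, and the remainder is validated as a models_library Node keyed by the stringified node_id. A hedged usage sketch (the UUID below is a made-up placeholder; in the route the repository is injected via get_repository as shown in computations.py):

# Usage sketch only; the project UUID is hypothetical.
from models_library.projects import NodesDict, ProjectID


async def load_workbench(repo: ProjectsNodesRepository) -> NodesDict:
    project_uuid = ProjectID("00000000-0000-0000-0000-000000000000")  # hypothetical
    workbench = await repo.get_nodes(project_uuid)
    for node_id, node in workbench.items():  # keys are stringified node UUIDs
        print(node_id, node.key, node.version)
    return workbench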

services/director-v2/src/simcore_service_director_v2/modules/db/tables.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 )
 from simcore_postgres_database.models.projects import ProjectType, projects
 from simcore_postgres_database.models.projects_networks import projects_networks
+from simcore_postgres_database.models.projects_nodes import projects_nodes
 
 __all__ = [
     "comp_pipeline",
@@ -15,6 +16,7 @@
     "groups_extra_properties",
     "NodeClass",
     "projects_networks",
+    "projects_nodes",
     "projects",
     "ProjectType",
     "StateType",

services/director-v2/src/simcore_service_director_v2/utils/dags.py

Lines changed: 4 additions & 5 deletions

@@ -42,8 +42,7 @@ def create_complete_dag(workbench: NodesDict) -> nx.DiGraph:
         )
         if node.input_nodes:
             for input_node_id in node.input_nodes:
-                predecessor_node = workbench.get(NodeIDStr(input_node_id))
-                if predecessor_node:
+                if f"{input_node_id}" in workbench:  # predecessor node
                     dag_graph.add_edge(str(input_node_id), node_id)
 
     return dag_graph
@@ -188,7 +187,7 @@ def compute_pipeline_started_timestamp(
     if not pipeline_dag.nodes:
         return None
     node_id_to_comp_task: dict[NodeIDStr, CompTaskAtDB] = {
-        NodeIDStr(f"{task.node_id}"): task for task in comp_tasks
+        f"{task.node_id}": task for task in comp_tasks
     }
     TOMORROW = arrow.utcnow().shift(days=1).datetime
     pipeline_started_at: datetime.datetime | None = min(
@@ -206,7 +205,7 @@ def compute_pipeline_stopped_timestamp(
     if not pipeline_dag.nodes:
         return None
     node_id_to_comp_task: dict[NodeIDStr, CompTaskAtDB] = {
-        NodeIDStr(f"{task.node_id}"): task for task in comp_tasks
+        f"{task.node_id}": task for task in comp_tasks
     }
     TOMORROW = arrow.utcnow().shift(days=1).datetime
     pipeline_stopped_at: datetime.datetime | None = max(
@@ -227,7 +226,7 @@ async def compute_pipeline_details(
 
     # NOTE: the latest progress is available in comp_tasks only
     node_id_to_comp_task: dict[NodeIDStr, CompTaskAtDB] = {
-        NodeIDStr(f"{task.node_id}"): task for task in comp_tasks
+        f"{task.node_id}": task for task in comp_tasks
     }
     pipeline_progress = None
     if len(pipeline_dag.nodes) > 0:
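
All four hunks apply the same simplification: workbench and node_id_to_comp_task are keyed by plain node-ID strings, so an f-string suffices for both lookup keys and membership tests, and the NodeIDStr wrapper plus the intermediate .get() drop out. A minimal illustration with made-up IDs (not the dags.py code itself):

# Made-up IDs; the real keys are stringified node UUIDs.
workbench = {"node-a": "upstream node"}
input_node_id = "node-a"  # a NodeID (UUID) in the real code

# before: predecessor_node = workbench.get(NodeIDStr(input_node_id))
#         if predecessor_node: ...
# after: a plain membership test on the stringified id
if f"{input_node_id}" in workbench:  # predecessor node exists
    print("add edge from", input_node_id)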
