Skip to content

Commit 120de89

Browse files
add yet another test
1 parent 73a95ec commit 120de89

File tree

2 files changed

+105
-20
lines changed

2 files changed

+105
-20
lines changed

packages/simcore-sdk/src/simcore_sdk/node_ports_common/dbmanager.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,35 @@ async def _get_node_from_db(
5151
return node
5252

5353

54+
async def _update_comp_run_snapshot_tasks_if_computational(
55+
engine: AsyncEngine,
56+
connection: AsyncConnection,
57+
project_id: str,
58+
node_uuid: str,
59+
node_configuration: dict,
60+
) -> None:
61+
"""
62+
Updates comp_run_snapshot_tasks table for computational nodes.
63+
"""
64+
node = await _get_node_from_db(project_id, node_uuid, connection)
65+
if node.node_class == NodeClass.COMPUTATIONAL.value:
66+
_latest_run_id = await get_latest_run_id_for_project(
67+
engine, connection, project_id=project_id
68+
)
69+
await update_for_run_id_and_node_id(
70+
engine,
71+
connection,
72+
run_id=_latest_run_id,
73+
node_id=node_uuid,
74+
data={
75+
"schema": node_configuration["schema"],
76+
"inputs": node_configuration["inputs"],
77+
"outputs": node_configuration["outputs"],
78+
"run_hash": node_configuration.get("run_hash"),
79+
},
80+
)
81+
82+
5483
class DBContextManager:
5584
def __init__(self, db_engine: AsyncEngine | None = None) -> None:
5685
self._db_engine: AsyncEngine | None = db_engine
@@ -113,26 +142,9 @@ async def write_ports_configuration(
113142
)
114143

115144
# 2. Update comp_run_snapshot_tasks table only if the node is computational
116-
node = await _get_node_from_db(project_id, node_uuid, connection)
117-
if node.node_class == NodeClass.COMPUTATIONAL.value:
118-
# 2.1 Get latest run id for the project
119-
_latest_run_id = await get_latest_run_id_for_project(
120-
engine, connection, project_id=project_id
121-
)
122-
123-
# 2.2 Update comp_run_snapshot_tasks table
124-
await update_for_run_id_and_node_id(
125-
engine,
126-
connection,
127-
run_id=_latest_run_id,
128-
node_id=node_uuid,
129-
data={
130-
"schema": node_configuration["schema"],
131-
"inputs": node_configuration["inputs"],
132-
"outputs": node_configuration["outputs"],
133-
"run_hash": node_configuration.get("run_hash"),
134-
},
135-
)
145+
await _update_comp_run_snapshot_tasks_if_computational(
146+
engine, connection, project_id, node_uuid, node_configuration
147+
)
136148

137149
async def get_ports_configuration_from_node_uuid(
138150
self, project_id: str, node_uuid: str
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import types
2+
from unittest.mock import AsyncMock
3+
4+
from simcore_sdk.node_ports_common import dbmanager
5+
6+
7+
async def test_update_comp_run_snapshot_tasks_if_computational(monkeypatch):
8+
engine = AsyncMock()
9+
connection = AsyncMock()
10+
project_id = "project-1"
11+
node_uuid = "node-1"
12+
node_configuration = {
13+
"schema": {"foo": "bar"},
14+
"inputs": {"a": 1},
15+
"outputs": {"b": 2},
16+
"run_hash": "hash123",
17+
}
18+
node = types.SimpleNamespace(node_class="COMPUTATIONAL")
19+
20+
get_node_mock = AsyncMock(return_value=node)
21+
get_latest_run_id_mock = AsyncMock(return_value="run-1")
22+
update_mock = AsyncMock()
23+
24+
monkeypatch.setattr(dbmanager, "_get_node_from_db", get_node_mock)
25+
monkeypatch.setattr(
26+
dbmanager, "get_latest_run_id_for_project", get_latest_run_id_mock
27+
)
28+
monkeypatch.setattr(dbmanager, "update_for_run_id_and_node_id", update_mock)
29+
30+
await dbmanager._update_comp_run_snapshot_tasks_if_computational(
31+
engine, connection, project_id, node_uuid, node_configuration
32+
)
33+
34+
get_node_mock.assert_awaited_once_with(project_id, node_uuid, connection)
35+
get_latest_run_id_mock.assert_awaited_once_with(
36+
engine, connection, project_id=project_id
37+
)
38+
update_mock.assert_awaited_once()
39+
_, kwargs = update_mock.call_args
40+
assert kwargs["run_id"] == "run-1"
41+
assert kwargs["node_id"] == node_uuid
42+
assert kwargs["data"]["schema"] == node_configuration["schema"]
43+
44+
45+
async def test_update_comp_run_snapshot_tasks_if_not_computational(monkeypatch):
46+
engine = AsyncMock()
47+
connection = AsyncMock()
48+
project_id = "project-2"
49+
node_uuid = "node-2"
50+
node_configuration = {
51+
"schema": {},
52+
"inputs": {},
53+
"outputs": {},
54+
}
55+
node = types.SimpleNamespace(node_class="ITERATIVE")
56+
57+
get_node_mock = AsyncMock(return_value=node)
58+
get_latest_run_id_mock = AsyncMock()
59+
update_mock = AsyncMock()
60+
61+
monkeypatch.setattr(dbmanager, "_get_node_from_db", get_node_mock)
62+
monkeypatch.setattr(
63+
dbmanager, "get_latest_run_id_for_project", get_latest_run_id_mock
64+
)
65+
monkeypatch.setattr(dbmanager, "update_for_run_id_and_node_id", update_mock)
66+
67+
await dbmanager._update_comp_run_snapshot_tasks_if_computational(
68+
engine, connection, project_id, node_uuid, node_configuration
69+
)
70+
71+
get_node_mock.assert_awaited_once_with(project_id, node_uuid, connection)
72+
get_latest_run_id_mock.assert_not_awaited()
73+
update_mock.assert_not_awaited()

0 commit comments

Comments
 (0)