Skip to content

Commit 5894e69

Browse files
committed
ongoing
1 parent a295c9a commit 5894e69

File tree

1 file changed

+38
-24
lines changed

1 file changed

+38
-24
lines changed

services/web/server/src/simcore_service_webserver/db_listener/_db_comp_tasks_listening_task.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from models_library.projects_nodes_io import NodeID
2020
from models_library.projects_state import RunningState
2121
from pydantic.types import PositiveInt
22+
from simcore_postgres_database.models.comp_tasks import comp_tasks
2223
from simcore_postgres_database.webserver_models import DB_CHANNEL_NAME, projects
2324
from sqlalchemy.sql import select
2425

@@ -61,59 +62,75 @@ async def _update_project_state(
6162
@dataclass(frozen=True)
6263
class _CompTaskNotificationPayload:
6364
action: str
64-
data: dict
6565
changes: dict
6666
table: str
67+
task_id: str | None = None
68+
project_id: str | None = None
69+
node_id: str | None = None
6770

6871

6972
async def _handle_db_notification(
7073
app: web.Application, payload: _CompTaskNotificationPayload, conn: SAConnection
7174
) -> None:
72-
task_data = payload.data
75+
project_uuid = payload.project_id
76+
node_uuid = payload.node_id
7377
task_changes = payload.changes
7478

75-
project_uuid = task_data.get("project_id", None)
76-
node_uuid = task_data.get("node_id", None)
7779
if any(x is None for x in [project_uuid, node_uuid]):
7880
_logger.warning(
7981
"comp_tasks row is corrupted. TIP: please check DB entry containing '%s'",
80-
f"{task_data=}",
82+
f"{payload=}",
8183
)
8284
return
8385

8486
assert project_uuid # nosec
8587
assert node_uuid # nosec
8688

8789
try:
88-
# NOTE: we need someone with the rights to modify that project. the owner is one.
89-
# find the user(s) linked to that project
9090
the_project_owner = await _get_project_owner(conn, project_uuid)
9191

92-
if any(f in task_changes for f in ["outputs", "run_hash"]):
93-
new_outputs = task_data.get("outputs", {})
94-
new_run_hash = task_data.get("run_hash", None)
92+
# Fetch the latest comp_tasks row for this node/project
93+
result = await conn.execute(
94+
select(comp_tasks).where(
95+
(comp_tasks.c.project_id == project_uuid)
96+
& (comp_tasks.c.node_id == node_uuid)
97+
)
98+
)
99+
row = await result.first()
100+
if not row:
101+
_logger.warning(
102+
"No comp_tasks row found for project_id=%s node_id=%s",
103+
project_uuid,
104+
node_uuid,
105+
)
106+
return
95107

108+
if any(f in task_changes for f in ["outputs", "run_hash"]):
109+
new_outputs = row.outputs if hasattr(row, "outputs") else {}
110+
new_run_hash = row.run_hash if hasattr(row, "run_hash") else None
111+
node_errors = row.errors if hasattr(row, "errors") else None
96112
await update_node_outputs(
97113
app,
98114
the_project_owner,
99115
ProjectID(project_uuid),
100116
NodeID(node_uuid),
101117
new_outputs,
102118
new_run_hash,
103-
node_errors=task_data.get("errors", None),
119+
node_errors=node_errors,
104120
ui_changed_keys=None,
105121
)
106122

107123
if "state" in task_changes:
108-
new_state = convert_state_from_db(task_data["state"])
109-
await _update_project_state(
110-
app,
111-
the_project_owner,
112-
ProjectID(project_uuid),
113-
NodeID(node_uuid),
114-
new_state,
115-
node_errors=task_data.get("errors", None),
116-
)
124+
new_state = row.state if hasattr(row, "state") else None
125+
if new_state is not None:
126+
await _update_project_state(
127+
app,
128+
the_project_owner,
129+
ProjectID(project_uuid),
130+
NodeID(node_uuid),
131+
convert_state_from_db(new_state),
132+
node_errors=row.errors if hasattr(row, "errors") else None,
133+
)
117134

118135
except exceptions.ProjectNotFoundError as exc:
119136
_logger.warning(
@@ -151,7 +168,6 @@ async def _listen(app: web.Application, db_engine: Engine) -> NoReturn:
151168
await asyncio.sleep(_LISTENING_TASK_BASE_SLEEPING_TIME_S)
152169
continue
153170
notification = conn.connection.notifies.get_nowait()
154-
# get the data and the info on what changed
155171
payload = _CompTaskNotificationPayload(**json_loads(notification.payload))
156172
_logger.debug("received update from database: %s", f"{payload=}")
157173
await _handle_db_notification(app, payload, conn)
@@ -161,12 +177,10 @@ async def _comp_tasks_listening_task(app: web.Application) -> None:
161177
_logger.info("starting comp_task db listening task...")
162178
while True:
163179
try:
164-
# create a special connection here
165180
db_engine = get_database_engine(app)
166181
_logger.info("listening to comp_task events...")
167182
await _listen(app, db_engine)
168-
except asyncio.CancelledError: # noqa: PERF203
169-
# we are closing the app..
183+
except asyncio.CancelledError:
170184
_logger.info("cancelled comp_tasks events")
171185
raise
172186
except Exception: # pylint: disable=broad-except

0 commit comments

Comments
 (0)