1919from models_library .projects_nodes_io import NodeID
2020from models_library .projects_state import RunningState
2121from pydantic .types import PositiveInt
22+ from simcore_postgres_database .models .comp_tasks import comp_tasks
2223from simcore_postgres_database .webserver_models import DB_CHANNEL_NAME , projects
2324from sqlalchemy .sql import select
2425
@@ -61,59 +62,75 @@ async def _update_project_state(
6162@dataclass (frozen = True )
6263class _CompTaskNotificationPayload :
6364 action : str
64- data : dict
6565 changes : dict
6666 table : str
67+ task_id : str | None = None
68+ project_id : str | None = None
69+ node_id : str | None = None
6770
6871
6972async def _handle_db_notification (
7073 app : web .Application , payload : _CompTaskNotificationPayload , conn : SAConnection
7174) -> None :
72- task_data = payload .data
75+ project_uuid = payload .project_id
76+ node_uuid = payload .node_id
7377 task_changes = payload .changes
7478
75- project_uuid = task_data .get ("project_id" , None )
76- node_uuid = task_data .get ("node_id" , None )
7779 if any (x is None for x in [project_uuid , node_uuid ]):
7880 _logger .warning (
7981 "comp_tasks row is corrupted. TIP: please check DB entry containing '%s'" ,
80- f"{ task_data = } " ,
82+ f"{ payload = } " ,
8183 )
8284 return
8385
8486 assert project_uuid # nosec
8587 assert node_uuid # nosec
8688
8789 try :
88- # NOTE: we need someone with the rights to modify that project. the owner is one.
89- # find the user(s) linked to that project
9090 the_project_owner = await _get_project_owner (conn , project_uuid )
9191
92- if any (f in task_changes for f in ["outputs" , "run_hash" ]):
93- new_outputs = task_data .get ("outputs" , {})
94- new_run_hash = task_data .get ("run_hash" , None )
92+ # Fetch the latest comp_tasks row for this node/project
93+ result = await conn .execute (
94+ select (comp_tasks ).where (
95+ (comp_tasks .c .project_id == project_uuid )
96+ & (comp_tasks .c .node_id == node_uuid )
97+ )
98+ )
99+ row = await result .first ()
100+ if not row :
101+ _logger .warning (
102+ "No comp_tasks row found for project_id=%s node_id=%s" ,
103+ project_uuid ,
104+ node_uuid ,
105+ )
106+ return
95107
108+ if any (f in task_changes for f in ["outputs" , "run_hash" ]):
109+ new_outputs = row .outputs if hasattr (row , "outputs" ) else {}
110+ new_run_hash = row .run_hash if hasattr (row , "run_hash" ) else None
111+ node_errors = row .errors if hasattr (row , "errors" ) else None
96112 await update_node_outputs (
97113 app ,
98114 the_project_owner ,
99115 ProjectID (project_uuid ),
100116 NodeID (node_uuid ),
101117 new_outputs ,
102118 new_run_hash ,
103- node_errors = task_data . get ( "errors" , None ) ,
119+ node_errors = node_errors ,
104120 ui_changed_keys = None ,
105121 )
106122
107123 if "state" in task_changes :
108- new_state = convert_state_from_db (task_data ["state" ])
109- await _update_project_state (
110- app ,
111- the_project_owner ,
112- ProjectID (project_uuid ),
113- NodeID (node_uuid ),
114- new_state ,
115- node_errors = task_data .get ("errors" , None ),
116- )
124+ new_state = row .state if hasattr (row , "state" ) else None
125+ if new_state is not None :
126+ await _update_project_state (
127+ app ,
128+ the_project_owner ,
129+ ProjectID (project_uuid ),
130+ NodeID (node_uuid ),
131+ convert_state_from_db (new_state ),
132+ node_errors = row .errors if hasattr (row , "errors" ) else None ,
133+ )
117134
118135 except exceptions .ProjectNotFoundError as exc :
119136 _logger .warning (
@@ -151,7 +168,6 @@ async def _listen(app: web.Application, db_engine: Engine) -> NoReturn:
151168 await asyncio .sleep (_LISTENING_TASK_BASE_SLEEPING_TIME_S )
152169 continue
153170 notification = conn .connection .notifies .get_nowait ()
154- # get the data and the info on what changed
155171 payload = _CompTaskNotificationPayload (** json_loads (notification .payload ))
156172 _logger .debug ("received update from database: %s" , f"{ payload = } " )
157173 await _handle_db_notification (app , payload , conn )
@@ -161,12 +177,10 @@ async def _comp_tasks_listening_task(app: web.Application) -> None:
161177 _logger .info ("starting comp_task db listening task..." )
162178 while True :
163179 try :
164- # create a special connection here
165180 db_engine = get_database_engine (app )
166181 _logger .info ("listening to comp_task events..." )
167182 await _listen (app , db_engine )
168- except asyncio .CancelledError : # noqa: PERF203
169- # we are closing the app..
183+ except asyncio .CancelledError :
170184 _logger .info ("cancelled comp_tasks events" )
171185 raise
172186 except Exception : # pylint: disable=broad-except
0 commit comments