2121from timeit import default_timer
2222from typing import Any , AsyncGenerator , Optional
2323
24+ import uuid
25+
2426from neo4j_graphrag .utils .logging import prettify
2527
2628try :
3941from neo4j_graphrag .experimental .pipeline .notification import (
4042 Event ,
4143 EventCallbackProtocol ,
42- EventType ,
43- PipelineEvent ,
44+ EventNotifier ,
4445)
4546from neo4j_graphrag .experimental .pipeline .orchestrator import Orchestrator
4647from neo4j_graphrag .experimental .pipeline .pipeline_graph import (
@@ -103,7 +104,7 @@ async def run(
103104 res = await self .execute (context , inputs )
104105 end_time = default_timer ()
105106 logger .debug (
106- f"TASK FINISHED { self .name } in { end_time - start_time } res={ prettify (res )} "
107+ f"TASK FINISHED { self .name } in { round ( end_time - start_time , 2 ) } s res={ prettify (res )} "
107108 )
108109 return res
109110
@@ -124,7 +125,6 @@ def __init__(
124125 ) -> None :
125126 super ().__init__ ()
126127 self .store = store or InMemoryStore ()
127- self .callbacks = [callback ] if callback else []
128128 self .final_results = InMemoryStore ()
129129 self .is_validated = False
130130 self .param_mapping : dict [str , dict [str , dict [str , str ]]] = defaultdict (dict )
@@ -139,6 +139,7 @@ def __init__(
139139 }
140140 """
141141 self .missing_inputs : dict [str , list [str ]] = defaultdict ()
142+ self .event_notifier = EventNotifier ([callback ] if callback else [])
142143
143144 @classmethod
144145 def from_template (
@@ -507,14 +508,13 @@ async def stream(
507508 """
508509 # Create queue for events
509510 event_queue : asyncio .Queue [Event ] = asyncio .Queue ()
510- run_id = None
511511
512512 async def event_stream (event : Event ) -> None :
513513 # Put event in queue for streaming
514514 await event_queue .put (event )
515515
516516 # Add event streaming callback
517- self .callbacks . append (event_stream )
517+ self .event_notifier . add_callback (event_stream )
518518
519519 event_queue_getter_task = None
520520 try :
@@ -542,39 +542,48 @@ async def event_stream(event: Event) -> None:
542542 # we are sure to get an Event here, since this is the only
543543 # thing we put in the queue, but mypy still complains
544544 event = event_future .result ()
545- run_id = getattr (event , "run_id" , None )
546545 yield event # type: ignore
547546
548547 if exc := run_task .exception ():
549- yield PipelineEvent (
550- event_type = EventType .PIPELINE_FAILED ,
551- # run_id is null if pipeline fails before even starting
552- # ie during pipeline validation
553- run_id = run_id or "" ,
554- message = str (exc ),
555- )
556548 if raise_exception :
557549 raise exc
558550
559551 finally :
560552 # Restore original callback
561- self .callbacks . remove (event_stream )
553+ self .event_notifier . remove_callback (event_stream )
562554 if event_queue_getter_task and not event_queue_getter_task .done ():
563555 event_queue_getter_task .cancel ()
564556
async def run(self, data: dict[str, Any]) -> PipelineResult:
    """Execute the pipeline once and return its result.

    A fresh UUID4 string is generated as the run id and passed, together
    with ``data``, to ``_run``, which does the actual work. If ``_run``
    raises, a "pipeline failed" notification carrying the error text is
    sent through ``self.event_notifier`` before the exception propagates.

    Args:
        data: Pipeline input, forwarded unchanged to ``_run``.

    Returns:
        The value returned by ``_run`` for this run.

    Raises:
        Exception: whatever ``_run`` raised, re-raised unchanged after the
            failure notification has been awaited.
    """
    start_time = default_timer()
    run_id = str(uuid.uuid4())
    logger.debug(f"PIPELINE START with {run_id=}")
    try:
        res = await self._run(run_id, data)
    except Exception as e:
        await self.event_notifier.notify_pipeline_failed(
            run_id,
            message=f"Pipeline failed with error {e}",
        )
        # Bare ``raise`` re-raises the active exception as-is instead of
        # re-binding it through a name.
        raise
    end_time = default_timer()
    logger.debug(
        f"PIPELINE FINISHED {run_id} in {round(end_time - start_time, 2)}s"
    )
    return res
574+
async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
    """Run the pipeline body for an already-allocated run id.

    Sends the "pipeline started" notification, resets and validates the
    pipeline against ``data``, runs an ``Orchestrator`` bound to
    ``run_id``, then sends the "pipeline finished" notification with the
    final results.

    Args:
        run_id: Identifier for this run, chosen by the caller.
        data: Pipeline input, passed to validation and to the orchestrator.

    Returns:
        A ``PipelineResult`` holding the orchestrator's run id and the
        final results fetched for it.
    """
    await self.event_notifier.notify_pipeline_started(run_id, data)
    self.invalidate()
    self.validate_input_data(data)
    orchestrator = Orchestrator(self, run_id)
    await orchestrator.run(data)
    # Fetch the final results once and reuse them for both the returned
    # object and the notification, rather than awaiting the lookup twice.
    final_results = await self.get_final_results(orchestrator.run_id)
    result = PipelineResult(
        run_id=orchestrator.run_id,
        result=final_results,
    )
    await self.event_notifier.notify_pipeline_finished(
        run_id,
        final_results,
    )
    return result
0 commit comments