# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
2020
2121
async def monitor(function, interval, timeout, *args, **kwargs):
    """Poll ``function`` until it returns a truthy value or ``timeout`` elapses.

    ``function`` is called as ``function(*args, **kwargs)`` on every poll. If
    the call produces an awaitable (a coroutine function, or any other callable
    that returns an awaitable), it is awaited and the awaited value is used as
    the poll outcome.

    :param function: callable to poll; may be synchronous or asynchronous.
    :param interval: number of seconds to sleep between two polls.
    :param timeout: maximum number of seconds to keep polling.
    :param args: positional arguments forwarded to ``function``.
    :param kwargs: keyword arguments forwarded to ``function``.
    :returns: ``None`` as soon as a poll returns a truthy value.
    :raises TimeoutError: if a poll is falsy and more than ``timeout`` seconds
        have elapsed since monitoring started.
    """
    import inspect
    import time

    # Use a monotonic clock: wall-clock time can jump when the system clock is
    # adjusted, which would shorten or lengthen the effective timeout.
    start_time = time.monotonic()
    while True:
        result = function(*args, **kwargs)
        # Decide on the *returned value* rather than on the callable's type, so
        # wrappers that merely return a coroutine are awaited as well. An
        # un-awaited coroutine object is truthy and would otherwise end the
        # monitoring on the very first poll.
        if inspect.isawaitable(result):
            result = await result
        if result:
            return
        # The timeout is only checked after an unsuccessful poll, so the
        # function is always called at least once.
        if time.monotonic() - start_time > timeout:
            raise TimeoutError(f"Function monitoring timed out after {timeout} seconds")
        await asyncio.sleep(interval)
37+
38+
2239async def task_run_job (process : Process , * args , ** kwargs ) -> Any :
2340 """Run the *async* user function and return results or a structured error."""
2441 node = process .node
@@ -41,10 +58,41 @@ async def task_run_job(process: Process, *args, **kwargs) -> Any:
4158 }
4259
4360
async def task_run_monitor_job(process: Process, *args, **kwargs) -> Any:
    """Poll the process function through ``monitor`` and return a structured payload.

    The function inputs of ``process`` are deserialized to raw Python data and
    passed as keyword arguments to ``process.func``, which is polled with the
    ``interval`` and ``timeout`` taken from ``process.inputs``.

    :param process: the process providing ``node``, ``inputs`` and ``func``.
    :param args: accepted but not used by this function.
    :param kwargs: accepted but not used by this function.
    :returns: ``{"__ok__": True, "results": ...}`` when monitoring completes, or
        a dict with ``"__error__"``, ``"exception"`` and ``"traceback"`` keys
        where ``"__error__"`` is ``"ERROR_TIMEOUT"`` for a ``TimeoutError`` and
        ``"ERROR_FUNCTION_EXECUTION_FAILED"`` for any other ``Exception``.
    """
    node = process.node

    # Note: deserialization happens outside the ``try`` below, so an error
    # raised here propagates to the caller instead of becoming a payload.
    function_inputs = deserialize_to_raw_python_data(
        dict(process.inputs.function_inputs or {}),
        deserializers=node.base.attributes.get(ATTR_DESERIALIZERS, {}),
    )

    try:
        logger.info(f"scheduled request to run the function<{node.pk}>")
        outcome = await monitor(
            process.func,
            interval=process.inputs.interval,
            timeout=process.inputs.timeout,
            **function_inputs,
        )
        logger.info(f"running function<{node.pk}> successful")
        return {"__ok__": True, "results": outcome}
    except Exception as exception:
        # ``TimeoutError`` is a subclass of ``Exception``; branch on it here to
        # pick the matching log message and error code.
        if isinstance(exception, TimeoutError):
            logger.warning(f"running function<{node.pk}> timed out")
            error_code = "ERROR_TIMEOUT"
        else:
            logger.warning(f"running function<{node.pk}> failed")
            error_code = "ERROR_FUNCTION_EXECUTION_FAILED"
        return {
            "__error__": error_code,
            "exception": str(exception),
            "traceback": traceback.format_exc(),
        }
88+
89+
4490@plumpy .persistence .auto_persist ("msg" , "data" )
4591class Waiting (plumpy .process_states .Waiting ):
4692 """The waiting state for the `PyFunction` process."""
4793
94+ task_run_job = staticmethod (task_run_job )
95+
4896 def __init__ (
4997 self ,
5098 process : Process ,
@@ -69,23 +117,17 @@ async def execute(self) -> plumpy.process_states.State:
69117 node = self .process .node
70118 node .set_process_status ("Running async function" )
71119 try :
72- payload = await self ._launch_task (task_run_job , self .process )
120+ payload = await self ._launch_task (self . task_run_job , self .process )
73121
74122 # Convert structured payloads into the next state or an ExitCode
75123 if payload .get ("__ok__" ):
76124 return self .parse (payload ["results" ])
77125 elif payload .get ("__error__" ):
78126 err = payload ["__error__" ]
79- if err == "ERROR_DESERIALIZE_INPUTS_FAILED" :
80- exit_code = self .process .exit_codes .ERROR_DESERIALIZE_INPUTS_FAILED .format (
81- exception = payload .get ("exception" , "" ),
82- traceback = payload .get ("traceback" , "" ),
83- )
84- else :
85- exit_code = self .process .exit_codes .ERROR_FUNCTION_EXECUTION_FAILED .format (
86- exception = payload .get ("exception" , "" ),
87- traceback = payload .get ("traceback" , "" ),
88- )
127+ exit_code = getattr (self .process .exit_codes , err ).format (
128+ exception = payload .get ("exception" , "" ),
129+ traceback = payload .get ("traceback" , "" ),
130+ )
89131 # Jump straight to FINISHED by scheduling parse with the error ExitCode
90132 # We reuse the Running->parse path so the process finishes uniformly.
91133 return self .create_state (ProcessState .RUNNING , self .process .parse , {"__exit_code__" : exit_code })
@@ -124,3 +166,9 @@ def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ig
124166 self ._killing = plumpy .futures .Future ()
125167 return self ._killing
126168 return None
169+
170+
class MonitorWaiting(Waiting):
    """Variant of ``Waiting`` whose ``task_run_job`` is ``task_run_monitor_job``."""

    # Stored as a staticmethod so that accessing it through an instance does
    # not bind the instance as the first argument.
    task_run_job = staticmethod(task_run_monitor_job)
0 commit comments