Skip to content

Commit 3934b7b

Browse files
committed
Add monitor function
1 parent a212e7a commit 3934b7b

File tree

6 files changed

+180
-20
lines changed

6 files changed

+180
-20
lines changed

src/aiida_pythonjob/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44

55
from node_graph import socket_spec as spec
66

7-
from .calculations import PyFunction, PythonJob
7+
from .calculations import MonitorPyFunction, PyFunction, PythonJob
88
from .decorator import pyfunction
9-
from .launch import prepare_pyfunction_inputs, prepare_pythonjob_inputs
9+
from .launch import prepare_monitor_function_inputs, prepare_pyfunction_inputs, prepare_pythonjob_inputs
1010
from .parsers import PythonJobParser
1111

1212
__all__ = (
13+
"MonitorPyFunction",
1314
"PyFunction",
1415
"PythonJob",
1516
"PythonJobParser",
17+
"prepare_monitor_function_inputs",
1618
"prepare_pyfunction_inputs",
1719
"prepare_pythonjob_inputs",
1820
"pyfunction",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .pyfunction import PyFunction
1+
from .pyfunction import MonitorPyFunction, PyFunction
22
from .pythonjob import PythonJob
33

4-
__all__ = ("PyFunction", "PythonJob")
4+
__all__ = ("MonitorPyFunction", "PyFunction", "PythonJob")

src/aiida_pythonjob/calculations/pyfunction.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import traceback
67
import typing as t
78

@@ -10,7 +11,8 @@
1011
from aiida.common.lang import override
1112
from aiida.engine import Process, ProcessSpec, ProcessState
1213
from aiida.engine.processes.exit_code import ExitCode
13-
from aiida.orm import CalcFunctionNode
14+
from aiida.orm import CalcFunctionNode, Float
15+
from aiida.orm.nodes.data.base import to_aiida_type
1416
from node_graph.socket_spec import SocketSpec
1517

1618
from aiida_pythonjob.calculations.common import (
@@ -23,7 +25,7 @@
2325
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data
2426
from aiida_pythonjob.parsers.utils import parse_outputs
2527

26-
from .tasks import Waiting
28+
from .tasks import MonitorWaiting, Waiting
2729

2830
# Public API of this module; both process classes are re-exported by the package.
__all__ = ("MonitorPyFunction", "PyFunction")
2931

@@ -32,6 +34,7 @@ class PyFunction(FunctionProcessMixin, Process):
3234
_node_class = CalcFunctionNode
3335
label_template = "{name}"
3436
default_name = "anonymous_function"
37+
_WAITING = Waiting
3538

3639
def __init__(self, *args, **kwargs) -> None:
3740
if kwargs.get("enable_persistence", False):
@@ -79,7 +82,7 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override]
7982
@classmethod
8083
def get_state_classes(cls) -> t.Dict[t.Hashable, t.Type[plumpy.process_states.State]]:
8184
states_map = super().get_state_classes()
82-
states_map[ProcessState.WAITING] = Waiting
85+
states_map[ProcessState.WAITING] = cls._WAITING
8386
return states_map
8487

8588
@override
@@ -96,8 +99,6 @@ def execute(self) -> dict[str, t.Any] | None:
9699

97100
@override
98101
async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
99-
import asyncio
100-
101102
if self.node.exit_status is not None:
102103
return ExitCode(self.node.exit_status, self.node.exit_message)
103104

@@ -148,3 +149,43 @@ def parse(self, results: t.Optional[dict] = None) -> ExitCode:
148149
for name, value in (outputs or {}).items():
149150
self.out(name, value)
150151
return ExitCode()
152+
153+
154+
class MonitorPyFunction(PyFunction):
    """``PyFunction`` subclass that uses ``MonitorWaiting`` as its waiting state.

    On top of the parent spec it declares the ``interval`` and ``timeout``
    inputs and the ``ERROR_TIMEOUT`` exit code.
    """

    default_name = "monitor_function"
    _WAITING = MonitorWaiting

    @classmethod
    def define(cls, spec: ProcessSpec) -> None:  # type: ignore[override]
        """Extend the parent spec with the polling inputs and the timeout exit code."""
        super().define(spec)

        # (input name, default value in seconds, help text); both are ``Float`` inputs.
        polling_inputs = (
            ("interval", 5.0, "Polling interval in seconds."),
            ("timeout", 3600.0, "Timeout in seconds."),
        )
        for port_name, seconds, description in polling_inputs:
            spec.input(
                port_name,
                valid_type=Float,
                # Bind the loop value as a default argument so each port keeps its own default.
                default=lambda seconds=seconds: Float(seconds),
                serializer=to_aiida_type,
                help=description,
            )

        spec.exit_code(
            324,
            "ERROR_TIMEOUT",
            invalidates_cache=True,
            message="Monitor function execution timed out.\n{exception}\n{traceback}",
        )

    @override
    async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
        """Return the stored exit code if the node already has one, otherwise enter the waiting state."""
        status = self.node.exit_status
        if status is None:
            return plumpy.process_states.Wait(msg="Waiting to run")
        return ExitCode(status, self.node.exit_message)

src/aiida_pythonjob/calculations/tasks.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22+
async def monitor(function, interval, timeout, *args, **kwargs):
    """Poll ``function`` until it returns a truthy value or ``timeout`` elapses.

    ``function`` is called with ``*args``/``**kwargs`` on every poll. It may be a
    plain callable or a coroutine function; any awaitable it returns is awaited
    before the result is tested for truthiness.

    :param function: sync or async callable whose truthy result ends the wait.
    :param interval: seconds to sleep between polls; any value accepted by ``float()``.
    :param timeout: maximum number of seconds to keep polling; any value accepted by ``float()``.
    :raises TimeoutError: if more than ``timeout`` seconds have elapsed after a falsy poll.
    :return: ``None``; the result of ``function`` is only used as the stop condition.
    """
    import inspect
    import time

    # Coerce once so numeric wrapper objects (anything defining ``__float__``)
    # are not compared against raw floats or handed to ``asyncio.sleep``.
    poll_interval = float(interval)
    max_wait = float(timeout)

    # A monotonic clock keeps the elapsed time immune to wall-clock adjustments.
    start_time = time.monotonic()
    while True:
        result = function(*args, **kwargs)
        # Covers coroutine functions as well as sync wrappers that return an awaitable.
        if inspect.isawaitable(result):
            result = await result
        if result:
            return None

        elapsed = time.monotonic() - start_time
        if elapsed > max_wait:
            raise TimeoutError(f"Function monitoring timed out after {timeout} seconds")
        # Never sleep past the deadline, so the timeout is detected on time
        # even when the polling interval is larger than the remaining time.
        await asyncio.sleep(min(poll_interval, max_wait - elapsed))
37+
38+
2239
async def task_run_job(process: Process, *args, **kwargs) -> Any:
2340
"""Run the *async* user function and return results or a structured error."""
2441
node = process.node
@@ -41,10 +58,41 @@ async def task_run_job(process: Process, *args, **kwargs) -> Any:
4158
}
4259

4360

61+
async def task_run_monitor_job(process: Process, *args, **kwargs) -> Any:
    """Poll the process function through ``monitor`` and return a structured payload.

    The ``function_inputs`` are deserialized to raw Python data and passed as
    keyword arguments to ``process.func`` on every poll. The polling interval
    and the timeout are read from the ``interval`` and ``timeout`` process inputs.

    :param process: process providing ``func``, ``inputs`` and ``node``.
    :return: ``{"__ok__": True, "results": ...}`` on success; otherwise a dict with
        ``"__error__"`` (``"ERROR_TIMEOUT"`` or ``"ERROR_FUNCTION_EXECUTION_FAILED"``),
        ``"exception"`` and ``"traceback"`` entries.
    """
    node = process.node

    inputs = dict(process.inputs.function_inputs or {})
    deserializers = node.base.attributes.get(ATTR_DESERIALIZERS, {})
    inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)

    # Hand plain floats to the polling loop instead of ORM nodes, so the event
    # loop's timer arithmetic never operates on (or creates) data nodes.
    interval = float(process.inputs.interval)
    timeout = float(process.inputs.timeout)

    try:
        logger.info(f"scheduled request to run the function<{node.pk}>")
        results = await monitor(process.func, interval=interval, timeout=timeout, **inputs)
        logger.info(f"running function<{node.pk}> successful")
        return {"__ok__": True, "results": results}
    except TimeoutError as exception:
        logger.warning(f"running function<{node.pk}> timed out")
        return {
            "__error__": "ERROR_TIMEOUT",
            "exception": str(exception),
            "traceback": traceback.format_exc(),
        }
    except Exception as exception:
        # Any other failure of the polled function is reported as an execution error.
        logger.warning(f"running function<{node.pk}> failed")
        return {
            "__error__": "ERROR_FUNCTION_EXECUTION_FAILED",
            "exception": str(exception),
            "traceback": traceback.format_exc(),
        }
88+
89+
4490
@plumpy.persistence.auto_persist("msg", "data")
4591
class Waiting(plumpy.process_states.Waiting):
4692
"""The waiting state for the `PyFunction` process."""
4793

94+
task_run_job = staticmethod(task_run_job)
95+
4896
def __init__(
4997
self,
5098
process: Process,
@@ -69,23 +117,17 @@ async def execute(self) -> plumpy.process_states.State:
69117
node = self.process.node
70118
node.set_process_status("Running async function")
71119
try:
72-
payload = await self._launch_task(task_run_job, self.process)
120+
payload = await self._launch_task(self.task_run_job, self.process)
73121

74122
# Convert structured payloads into the next state or an ExitCode
75123
if payload.get("__ok__"):
76124
return self.parse(payload["results"])
77125
elif payload.get("__error__"):
78126
err = payload["__error__"]
79-
if err == "ERROR_DESERIALIZE_INPUTS_FAILED":
80-
exit_code = self.process.exit_codes.ERROR_DESERIALIZE_INPUTS_FAILED.format(
81-
exception=payload.get("exception", ""),
82-
traceback=payload.get("traceback", ""),
83-
)
84-
else:
85-
exit_code = self.process.exit_codes.ERROR_FUNCTION_EXECUTION_FAILED.format(
86-
exception=payload.get("exception", ""),
87-
traceback=payload.get("traceback", ""),
88-
)
127+
exit_code = getattr(self.process.exit_codes, err).format(
128+
exception=payload.get("exception", ""),
129+
traceback=payload.get("traceback", ""),
130+
)
89131
# Jump straight to FINISHED by scheduling parse with the error ExitCode
90132
# We reuse the Running->parse path so the process finishes uniformly.
91133
return self.create_state(ProcessState.RUNNING, self.process.parse, {"__exit_code__": exit_code})
@@ -124,3 +166,9 @@ def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ig
124166
self._killing = plumpy.futures.Future()
125167
return self._killing
126168
return None
169+
170+
171+
class MonitorWaiting(Waiting):
    """Waiting state that launches ``task_run_monitor_job`` as its task."""

    # Stored as a staticmethod so that ``self.task_run_job`` yields the plain
    # coroutine function rather than a method bound to the state instance.
    task_run_job = staticmethod(task_run_monitor_job)

src/aiida_pythonjob/launch.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,39 @@ def prepare_pyfunction_inputs(
260260
if process_label:
261261
inputs["process_label"] = process_label
262262
return inputs
263+
264+
265+
def prepare_monitor_function_inputs(
    function: Optional[Callable[..., Any]] = None,
    function_inputs: Optional[Dict[str, Any]] = None,
    inputs_spec: Optional[type] = None,
    outputs_spec: Optional[type] = None,
    metadata: Optional[Dict[str, Any]] = None,
    process_label: Optional[str] = None,
    function_data: dict | None = None,
    deserializers: dict | None = None,
    serializers: dict | None = None,
    register_pickle_by_value: bool = False,
    interval: Optional[Union[int, float, orm.Float, orm.Int]] = None,
    timeout: Optional[Union[int, float, orm.Float, orm.Int]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Prepare the inputs for a monitor function (no Code/upload_files).

    All arguments except ``interval`` and ``timeout`` are forwarded unchanged to
    ``prepare_pyfunction_inputs``.

    :param interval: polling interval in seconds; defaults to 5.0 when ``None``.
    :param timeout: timeout in seconds; defaults to 3600.0 when ``None``.
    :return: the inputs dictionary extended with ``interval`` and ``timeout``
        entries stored as ``orm.Float``.
    """
    inputs = prepare_pyfunction_inputs(
        function=function,
        function_inputs=function_inputs,
        inputs_spec=inputs_spec,
        outputs_spec=outputs_spec,
        metadata=metadata,
        process_label=process_label,
        function_data=function_data,
        deserializers=deserializers,
        serializers=serializers,
        register_pickle_by_value=register_pickle_by_value,
        **kwargs,
    )
    # Fallback values mirror the defaults declared for the ``interval`` and
    # ``timeout`` inputs of ``MonitorPyFunction`` (5.0 s and 3600.0 s).
    inputs["interval"] = orm.Float(5.0 if interval is None else interval)
    inputs["timeout"] = orm.Float(3600.0 if timeout is None else timeout)
    return inputs

tests/test_monitor.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import datetime
2+
3+
from aiida.engine import run_get_node
4+
5+
from aiida_pythonjob import MonitorPyFunction, prepare_monitor_function_inputs
6+
7+
8+
def monitor_time(time: datetime.datetime):
    """Return ``True`` once the current local time is later than ``time``."""
    now = datetime.datetime.now()
    return time < now
10+
11+
12+
def test_async_function_runs_and_returns_result():
    """A monitor whose condition becomes true finishes OK with a ``None`` result."""
    inputs = prepare_monitor_function_inputs(
        monitor_time,
        function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=5)},
        # Poll explicitly every second so the test does not depend on the
        # helper's default interval and ends shortly after the target time.
        interval=1.0,
    )
    result, node = run_get_node(MonitorPyFunction, **inputs)
    assert node.is_finished_ok
    assert "result" in result
    # The actual monitor function returns None
    assert result["result"].value is None
22+
23+
24+
def test_async_function_raises_produces_exit_code():
    """A monitor whose condition stays false past the timeout ends with ``ERROR_TIMEOUT``."""
    inputs = prepare_monitor_function_inputs(
        monitor_time,
        function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=20)},
        # Poll more often than the timeout so the timeout is detected promptly
        # instead of only after one full default polling interval.
        interval=1.0,
        timeout=5.0,
    )
    _, node = run_get_node(MonitorPyFunction, **inputs)
    assert not node.is_finished_ok
    assert node.exit_status == MonitorPyFunction.exit_codes.ERROR_TIMEOUT.status
    assert "Monitor function execution timed out." in node.exit_message

0 commit comments

Comments
 (0)