Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions docs/gallery/autogen/pyfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def generate_structures(element: str, factors: list) -> dict:

from aiida.engine import submit
import datetime
from aiida_pythonjob import prepare_pyfunction_inputs
from aiida_pythonjob import prepare_pyfunction_inputs, PyFunction


@pyfunction()
Expand All @@ -111,39 +111,62 @@ async def add_async(x, y, time: float):
function_inputs={"x": 2, "y": 3, "time": 2.0},
)

node = submit(add_async, **inputs)
node = submit(PyFunction, **inputs)

# %%
# Here is an example to monitor external events or conditions without blocking.
# %%#
# Killing an async process
# ~~~~~~~~~~~~~~~~~~~~~~~~
# Since async functions run as regular AiiDA processes, they can be controlled and killed
# programmatically. This is useful for managing long-running or stuck tasks.
# You can kill a running async function using the AiiDA command line interface.
#
# .. code-block:: bash
#
# $ verdi process kill <pk>
#
# Monitor external events
# ------------------------
#
# Async functions are particularly useful for monitoring external events or conditions without blocking the AiiDA daemon.
# Here is an example that waits until a specified time.
#


@pyfunction()
async def monitor_time(time: datetime.datetime):
async def monitor_time(time: datetime.datetime, interval: float = 0.5, timeout: float = 60.0):
"""Monitor the current time until it reaches the specified target time."""
import asyncio

# monitor until the specified time
start_time = datetime.datetime.now()
while datetime.datetime.now() < time:
print("Waiting...")
await asyncio.sleep(0.5)
await asyncio.sleep(interval)
if (datetime.datetime.now() - start_time).total_seconds() > timeout:
raise TimeoutError("Monitoring timed out.")


inputs = prepare_pyfunction_inputs(
monitor_time,
function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=5)},
function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=5), "interval": 1.0},
)

node = submit(monitor_time, **inputs)
node = submit(PyFunction, **inputs)
# %%
# For convenience, we provide a dedicated ``MonitorPyFunction`` class that inherits from ``PyFunction``.
# Users only need to write a normal function that returns ``True`` when the monitoring condition is met.

# %%#
# Killing an async process
# ------------------------
# Since async functions run as regular AiiDA processes, they can be controlled and killed
# programmatically. This is useful for managing long-running or stuck tasks.
# You can kill a running async function using the AiiDA command line interface.
#
# .. code-block:: bash
#
# $ verdi process kill <pk>
#
from aiida_pythonjob import MonitorPyFunction


def monitor_time(time: datetime.datetime):
    """Report whether the target ``time`` has already passed."""
    # The condition is met (True) once the current local time is later than the target.
    now = datetime.datetime.now()
    return now > time


# Build the inputs for ``monitor_time`` with a target time five seconds from now and submit them
# to ``MonitorPyFunction``.
# NOTE(review): ``interval`` and ``timeout`` are passed to ``prepare_pyfunction_inputs`` as extra
# keyword arguments here, while the package also exports ``prepare_monitor_function_inputs`` with
# explicit ``interval``/``timeout`` parameters -- confirm that the generic helper forwards these
# extra keyword arguments as process inputs.
inputs = prepare_pyfunction_inputs(
    monitor_time,
    function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=5)},
    interval=1.0,
    timeout=20.0,
)

node = submit(MonitorPyFunction, **inputs)
6 changes: 4 additions & 2 deletions src/aiida_pythonjob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

from node_graph import socket_spec as spec

from .calculations import PyFunction, PythonJob
from .calculations import MonitorPyFunction, PyFunction, PythonJob
from .decorator import pyfunction
from .launch import prepare_pyfunction_inputs, prepare_pythonjob_inputs
from .launch import prepare_monitor_function_inputs, prepare_pyfunction_inputs, prepare_pythonjob_inputs
from .parsers import PythonJobParser

__all__ = (
"MonitorPyFunction",
"PyFunction",
"PythonJob",
"PythonJobParser",
"prepare_monitor_function_inputs",
"prepare_pyfunction_inputs",
"prepare_pythonjob_inputs",
"pyfunction",
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_pythonjob/calculations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .pyfunction import PyFunction
from .pyfunction import MonitorPyFunction, PyFunction
from .pythonjob import PythonJob

__all__ = ("PyFunction", "PythonJob")
__all__ = ("MonitorPyFunction", "PyFunction", "PythonJob")
51 changes: 46 additions & 5 deletions src/aiida_pythonjob/calculations/pyfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import traceback
import typing as t

Expand All @@ -10,7 +11,8 @@
from aiida.common.lang import override
from aiida.engine import Process, ProcessSpec, ProcessState
from aiida.engine.processes.exit_code import ExitCode
from aiida.orm import CalcFunctionNode
from aiida.orm import CalcFunctionNode, Float
from aiida.orm.nodes.data.base import to_aiida_type
from node_graph.socket_spec import SocketSpec

from aiida_pythonjob.calculations.common import (
Expand All @@ -23,7 +25,7 @@
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data
from aiida_pythonjob.parsers.utils import parse_outputs

from .tasks import Waiting
from .tasks import MonitorWaiting, Waiting

__all__ = ("PyFunction",)

Expand All @@ -32,6 +34,7 @@ class PyFunction(FunctionProcessMixin, Process):
_node_class = CalcFunctionNode
label_template = "{name}"
default_name = "anonymous_function"
_WAITING = Waiting

def __init__(self, *args, **kwargs) -> None:
if kwargs.get("enable_persistence", False):
Expand Down Expand Up @@ -79,7 +82,7 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override]
@classmethod
def get_state_classes(cls) -> t.Dict[t.Hashable, t.Type[plumpy.process_states.State]]:
states_map = super().get_state_classes()
states_map[ProcessState.WAITING] = Waiting
states_map[ProcessState.WAITING] = cls._WAITING
return states_map

@override
Expand All @@ -96,8 +99,6 @@ def execute(self) -> dict[str, t.Any] | None:

@override
async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
import asyncio

if self.node.exit_status is not None:
return ExitCode(self.node.exit_status, self.node.exit_message)

Expand Down Expand Up @@ -148,3 +149,43 @@ def parse(self, results: t.Optional[dict] = None) -> ExitCode:
for name, value in (outputs or {}).items():
self.out(name, value)
return ExitCode()


class MonitorPyFunction(PyFunction):
    """``PyFunction`` variant that adds polling inputs and a timeout exit code.

    Compared to ``PyFunction`` this class uses ``MonitorWaiting`` as its waiting state,
    declares the ``interval`` and ``timeout`` inputs, and registers the ``ERROR_TIMEOUT``
    exit code.
    """

    default_name = "monitor_function"
    _WAITING = MonitorWaiting

    @classmethod
    def define(cls, spec: ProcessSpec) -> None:  # type: ignore[override]
        """Extend the parent spec with the polling inputs and the timeout exit code."""
        super().define(spec)

        # Both polling ports hold a ``Float`` and share the same serializer.
        float_port = {"valid_type": Float, "serializer": to_aiida_type}
        spec.input(
            "interval",
            default=lambda: Float(5.0),
            help="Polling interval in seconds.",
            **float_port,
        )
        spec.input(
            "timeout",
            default=lambda: Float(3600.0),
            help="Timeout in seconds.",
            **float_port,
        )

        spec.exit_code(
            324,
            "ERROR_TIMEOUT",
            message="Monitor function execution timed out.\n{exception}\n{traceback}",
            invalidates_cache=True,
        )

    @override
    async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
        """Return the stored exit code if one is already set, otherwise enter the waiting state."""
        exit_status = self.node.exit_status
        if exit_status is None:
            return plumpy.process_states.Wait(msg="Waiting to run")
        return ExitCode(exit_status, self.node.exit_message)
70 changes: 59 additions & 11 deletions src/aiida_pythonjob/calculations/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@
logger = logging.getLogger(__name__)


async def monitor(function, interval, timeout, *args, **kwargs):
    """Poll ``function`` until it returns a truthy value or ``timeout`` elapses.

    ``function`` is called with ``*args`` and ``**kwargs`` on every poll. It may be a
    plain callable or anything whose call produces an awaitable (a coroutine function,
    a ``functools.partial`` around one, an object with an async ``__call__``); an
    awaitable result is awaited before it is tested for truthiness.

    :param function: The condition to poll.
    :param interval: Seconds to sleep between polls; anything accepted by ``float()``.
    :param timeout: Maximum number of seconds to keep polling; anything accepted by ``float()``.
    :returns: ``None`` once the condition is met.
    :raises TimeoutError: If the condition is still falsy after more than ``timeout`` seconds.
    """
    import inspect
    import time

    # Coerce once up front: callers may pass number-like objects (e.g. AiiDA ``Float``
    # nodes) and the arithmetic, the sleep call and the error message below should all
    # work on plain floats.
    interval = float(interval)
    timeout = float(timeout)

    # A monotonic clock is immune to wall-clock adjustments while we are waiting.
    started = time.monotonic()
    while True:
        result = function(*args, **kwargs)
        # Checking the *result* rather than the callable also covers wrappers and
        # callable objects that only reveal their async nature when called.
        if inspect.isawaitable(result):
            result = await result
        if result:
            return
        if time.monotonic() - started > timeout:
            raise TimeoutError(f"Function monitoring timed out after {timeout} seconds")
        await asyncio.sleep(interval)


async def task_run_job(process: Process, *args, **kwargs) -> Any:
"""Run the *async* user function and return results or a structured error."""
node = process.node
Expand All @@ -41,10 +58,41 @@ async def task_run_job(process: Process, *args, **kwargs) -> Any:
}


async def task_run_monitor_job(process: Process, *args, **kwargs) -> Any:
    """Poll the process's monitor function and return results or a structured error.

    The function inputs are deserialized to raw Python data and ``process.func`` is
    polled through ``monitor`` using the process's ``interval`` and ``timeout`` inputs.

    :param process: The process whose ``func`` is polled.
    :returns: ``{"__ok__": True, "results": ...}`` on success; otherwise a dict with
        ``"__error__"`` set to ``"ERROR_TIMEOUT"`` (polling timed out) or
        ``"ERROR_FUNCTION_EXECUTION_FAILED"`` (any other exception), together with the
        ``"exception"`` message and the formatted ``"traceback"``.
    """
    import functools

    node = process.node

    inputs = dict(process.inputs.function_inputs or {})
    deserializers = node.base.attributes.get(ATTR_DESERIALIZERS, {})
    inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)

    # Bind the user's keyword arguments up front so that a user parameter named
    # ``function``, ``interval`` or ``timeout`` cannot collide with the parameters
    # of ``monitor`` itself.
    condition = functools.partial(process.func, **inputs)
    # Hand plain floats to the polling loop rather than the stored ``Float`` nodes.
    interval = float(process.inputs.interval)
    timeout = float(process.inputs.timeout)

    try:
        logger.info(f"scheduled request to run the function<{node.pk}>")
        results = await monitor(condition, interval, timeout)
        logger.info(f"running function<{node.pk}> successful")
        return {"__ok__": True, "results": results}
    except TimeoutError as exception:
        logger.warning(f"running function<{node.pk}> timed out")
        return {
            "__error__": "ERROR_TIMEOUT",
            "exception": str(exception),
            "traceback": traceback.format_exc(),
        }
    except Exception as exception:
        logger.warning(f"running function<{node.pk}> failed")
        return {
            "__error__": "ERROR_FUNCTION_EXECUTION_FAILED",
            "exception": str(exception),
            "traceback": traceback.format_exc(),
        }


@plumpy.persistence.auto_persist("msg", "data")
class Waiting(plumpy.process_states.Waiting):
"""The waiting state for the `PyFunction` process."""

task_run_job = staticmethod(task_run_job)

def __init__(
self,
process: Process,
Expand All @@ -69,23 +117,17 @@ async def execute(self) -> plumpy.process_states.State:
node = self.process.node
node.set_process_status("Running async function")
try:
payload = await self._launch_task(task_run_job, self.process)
payload = await self._launch_task(self.task_run_job, self.process)

# Convert structured payloads into the next state or an ExitCode
if payload.get("__ok__"):
return self.parse(payload["results"])
elif payload.get("__error__"):
err = payload["__error__"]
if err == "ERROR_DESERIALIZE_INPUTS_FAILED":
exit_code = self.process.exit_codes.ERROR_DESERIALIZE_INPUTS_FAILED.format(
exception=payload.get("exception", ""),
traceback=payload.get("traceback", ""),
)
else:
exit_code = self.process.exit_codes.ERROR_FUNCTION_EXECUTION_FAILED.format(
exception=payload.get("exception", ""),
traceback=payload.get("traceback", ""),
)
exit_code = getattr(self.process.exit_codes, err).format(
exception=payload.get("exception", ""),
traceback=payload.get("traceback", ""),
)
# Jump straight to FINISHED by scheduling parse with the error ExitCode
# We reuse the Running->parse path so the process finishes uniformly.
return self.create_state(ProcessState.RUNNING, self.process.parse, {"__exit_code__": exit_code})
Expand Down Expand Up @@ -124,3 +166,9 @@ def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ig
self._killing = plumpy.futures.Future()
return self._killing
return None


class MonitorWaiting(Waiting):
    """Waiting state that runs ``task_run_monitor_job`` instead of ``task_run_job``.

    Only the ``task_run_job`` class attribute is overridden; everything else is
    inherited from ``Waiting``.
    """

    task_run_job = staticmethod(task_run_monitor_job)
36 changes: 36 additions & 0 deletions src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,39 @@ def prepare_pyfunction_inputs(
if process_label:
inputs["process_label"] = process_label
return inputs


def prepare_monitor_function_inputs(
    function: Optional[Callable[..., Any]] = None,
    function_inputs: Optional[Dict[str, Any]] = None,
    inputs_spec: Optional[type] = None,
    outputs_spec: Optional[type] = None,
    metadata: Optional[Dict[str, Any]] = None,
    process_label: Optional[str] = None,
    function_data: dict | None = None,
    deserializers: dict | None = None,
    serializers: dict | None = None,
    register_pickle_by_value: bool = False,
    interval: Optional[Union[int, float, orm.Float, orm.Int]] = None,
    timeout: Optional[Union[int, float, orm.Float, orm.Int]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Prepare the inputs for a monitor function (no Code/upload_files).

    All arguments other than ``interval`` and ``timeout`` are forwarded unchanged to
    ``prepare_pyfunction_inputs``.

    :param interval: Polling interval in seconds. When ``None`` the key is omitted so
        that the default declared on the process spec applies.
    :param timeout: Timeout in seconds. When ``None`` the key is omitted so that the
        default declared on the process spec applies.
    :returns: The inputs dictionary, extended with ``interval``/``timeout`` when given.
    """
    inputs = prepare_pyfunction_inputs(
        function=function,
        function_inputs=function_inputs,
        inputs_spec=inputs_spec,
        outputs_spec=outputs_spec,
        metadata=metadata,
        process_label=process_label,
        function_data=function_data,
        deserializers=deserializers,
        serializers=serializers,
        register_pickle_by_value=register_pickle_by_value,
        **kwargs,
    )
    # Only forward values the caller actually supplied. ``MonitorPyFunction`` declares
    # its own defaults for both ports; a second hard-coded fallback here (previously
    # 10.0 s for ``interval`` versus 5.0 s in the spec) made the effective default
    # depend on which helper built the inputs.
    if interval is not None:
        inputs["interval"] = orm.Float(interval)
    if timeout is not None:
        inputs["timeout"] = orm.Float(timeout)
    return inputs
32 changes: 32 additions & 0 deletions tests/test_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import datetime

from aiida.engine import run_get_node

from aiida_pythonjob import MonitorPyFunction, prepare_monitor_function_inputs


def monitor_time(time: datetime.datetime):
    """Report whether the current local time is later than ``time``."""
    now = datetime.datetime.now()
    return now > time


def test_async_function_runs_and_returns_result():
    """A monitor whose target time lies five seconds ahead finishes successfully."""
    target = datetime.datetime.now() + datetime.timedelta(seconds=5)
    inputs = prepare_monitor_function_inputs(monitor_time, function_inputs={"time": target})

    outputs, node = run_get_node(MonitorPyFunction, **inputs)

    assert node.is_finished_ok
    # The actual monitor function returns None
    assert outputs["result"].value is None


def test_async_function_raises_produces_exit_code():
    """A five-second timeout with a target twenty seconds ahead ends in ``ERROR_TIMEOUT``."""
    target = datetime.datetime.now() + datetime.timedelta(seconds=20)
    inputs = prepare_monitor_function_inputs(
        monitor_time,
        function_inputs={"time": target},
        timeout=5.0,
    )

    _, node = run_get_node(MonitorPyFunction, **inputs)

    expected_status = MonitorPyFunction.exit_codes.ERROR_TIMEOUT.status
    assert not node.is_finished_ok
    assert node.exit_status == expected_status
    assert "Monitor function execution timed out." in node.exit_message