Skip to content

Commit a212e7a

Browse files
authored
PyFunction supports async function (#51)
- Async user function -> schedule via Waiting (interruptible). - Sync user function -> execute now, parse outputs, and finish. This feature needs aiida-core>=2.7.
1 parent cbaff61 commit a212e7a

File tree

14 files changed

+356
-16
lines changed

14 files changed

+356
-16
lines changed

.github/workflows/ci-docs-format.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ jobs:
3434
python-version: '3.12'
3535
cache: 'pip'
3636
- name: Install ruff
37-
run: python -m pip install ruff==0.6.9
37+
run: python -m pip install ruff==0.11.2
3838
- name: Run formatter and linter
3939
run: |
40-
ruff format --check .
41-
ruff check .
40+
ruff format --check . --exclude docs/
41+
ruff check . --exclude docs/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
- id: trailing-whitespace
99

1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.6.9
11+
rev: v0.11.2
1212
hooks:
1313
- id: ruff
1414
args: ["--fix", "--line-length=121", "--ignore=F821,F722,E203"]

docs/environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- aiida-core~=2.6.3
76
- aiida-core.services

docs/gallery/autogen/pyfunction.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,72 @@ def generate_structures(element: str, factors: list) -> dict:
7878
print("Generated scaled structures:")
7979
for key, value in result.items():
8080
print(key, value)
81+
82+
83+
# %%
84+
# Async functions
85+
# ---------------
86+
# ``pyfunction`` also supports Python's ``async`` functions. This is a powerful feature for
87+
# tasks that are I/O-bound (e.g., waiting for network requests, file operations) or for
88+
# running multiple tasks concurrently without blocking the AiiDA daemon.
89+
#
90+
# When you ``submit`` an async function, the call returns immediately with a process node,
91+
# allowing your script to continue running while the function executes in the background.
92+
#
93+
94+
from aiida.engine import submit
95+
import datetime
96+
from aiida_pythonjob import prepare_pyfunction_inputs
97+
98+
99+
@pyfunction()
100+
async def add_async(x, y, time: float):
101+
"""A simple function that adds two numbers."""
102+
import asyncio
103+
104+
# Simulate asynchronous I/O or computation
105+
await asyncio.sleep(time)
106+
return x + y
107+
108+
109+
inputs = prepare_pyfunction_inputs(
110+
add_async,
111+
function_inputs={"x": 2, "y": 3, "time": 2.0},
112+
)
113+
114+
node = submit(add_async, **inputs)
115+
116+
# %%
117+
# You can also monitor external events or conditions without blocking.
117+
# The following example waits until a specified time is reached.
119+
#
120+
121+
122+
@pyfunction()
123+
async def monitor_time(time: datetime.datetime):
124+
import asyncio
125+
126+
# monitor until the specified time
127+
while datetime.datetime.now() < time:
128+
print("Waiting...")
129+
await asyncio.sleep(0.5)
130+
131+
132+
inputs = prepare_pyfunction_inputs(
133+
monitor_time,
134+
function_inputs={"time": datetime.datetime.now() + datetime.timedelta(seconds=5)},
135+
)
136+
137+
node = submit(monitor_time, **inputs)
138+
139+
# %%
140+
# Killing an async process
141+
# ------------------------
142+
# Since async functions run as regular AiiDA processes, they can be controlled and killed
143+
# programmatically. This is useful for managing long-running or stuck tasks.
144+
# You can kill a running async function using the AiiDA command line interface.
145+
#
146+
# .. code-block:: bash
147+
#
148+
# $ verdi process kill <pk>
149+
#

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ classifiers = [
2020
keywords = ["aiida", "plugin"]
2121
requires-python = ">=3.9"
2222
dependencies = [
23-
"aiida-core>=2.3,<3",
23+
"aiida-core>=2.7.1,<3",
2424
"ase",
2525
"node-graph>=0.3.9",
2626
]
@@ -54,6 +54,7 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
5454
"pythonjob.jsonable_data" = "aiida_pythonjob.data.jsonable_data:JsonableData"
5555
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
5656
"pythonjob.builtins.NoneType" = "aiida_pythonjob.data.common_data:NoneData"
57+
"pythonjob.datetime.datetime" = "aiida_pythonjob.data.common_data:DateTimeData"
5758
"pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int"
5859
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
5960
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"

src/aiida_pythonjob/calculations/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def add_common_function_io(spec) -> None:
4747
320,
4848
"ERROR_DESERIALIZE_INPUTS_FAILED",
4949
invalidates_cache=True,
50-
message="Failed to unpickle inputs.\n{exception}\n{traceback}",
50+
message="Failed to deserialize inputs.\n{exception}\n{traceback}",
5151
)
5252
spec.exit_code(
5353
321,

src/aiida_pythonjob/calculations/pyfunction.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import cloudpickle
99
import plumpy
1010
from aiida.common.lang import override
11-
from aiida.engine import Process, ProcessSpec
11+
from aiida.engine import Process, ProcessSpec, ProcessState
1212
from aiida.engine.processes.exit_code import ExitCode
1313
from aiida.orm import CalcFunctionNode
1414
from node_graph.socket_spec import SocketSpec
@@ -23,12 +23,12 @@
2323
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data
2424
from aiida_pythonjob.parsers.utils import parse_outputs
2525

26+
from .tasks import Waiting
27+
2628
__all__ = ("PyFunction",)
2729

2830

2931
class PyFunction(FunctionProcessMixin, Process):
30-
"""Run a Python function in-process, using :class:`SocketSpec` for I/O."""
31-
3232
_node_class = CalcFunctionNode
3333
label_template = "{name}"
3434
default_name = "anonymous_function"
@@ -76,6 +76,12 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override]
7676
message="Function execution failed.\n{exception}\n{traceback}",
7777
)
7878

79+
@classmethod
80+
def get_state_classes(cls) -> t.Dict[t.Hashable, t.Type[plumpy.process_states.State]]:
81+
states_map = super().get_state_classes()
82+
states_map[ProcessState.WAITING] = Waiting
83+
return states_map
84+
7985
@override
8086
def _setup_db_record(self) -> None:
8187
super()._setup_db_record()
@@ -89,12 +95,19 @@ def execute(self) -> dict[str, t.Any] | None:
8995
return result
9096

9197
@override
92-
def run(self) -> ExitCode | None:
93-
# Respect caching semantics (from aiida-core calcfunction implementation)
98+
async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
99+
import asyncio
100+
94101
if self.node.exit_status is not None:
95102
return ExitCode(self.node.exit_status, self.node.exit_message)
96103

97-
# Deserialize inputs
104+
func = self.func
105+
106+
# Async user function -> schedule via Waiting (interruptable)
107+
if asyncio.iscoroutinefunction(func):
108+
return plumpy.process_states.Wait(msg="Waiting to run")
109+
110+
# Sync user function -> execute now, parse outputs, and finish
98111
try:
99112
inputs = dict(self.inputs.function_inputs or {})
100113
deserializers = self.node.base.attributes.get(ATTR_DESERIALIZERS, {})
@@ -104,17 +117,24 @@ def run(self) -> ExitCode | None:
104117
exception=str(exception), traceback=traceback.format_exc()
105118
)
106119

107-
# Execute function
108120
try:
109121
results = self.func(**inputs)
110122
except Exception as exception:
111123
return self.exit_codes.ERROR_FUNCTION_EXECUTION_FAILED.format(
112124
exception=str(exception), traceback=traceback.format_exc()
113125
)
114126

115-
# Parse & attach outputs
127+
return self.parse(results)
128+
129+
def parse(self, results: t.Optional[dict] = None) -> ExitCode:
130+
"""Parse and attach outputs, or short-circuit with a provided ExitCode."""
131+
# Short-circuit: Waiting handed us a pre-built ExitCode
132+
if isinstance(results, dict) and "__exit_code__" in results:
133+
return results["__exit_code__"]
134+
116135
outputs_spec = SocketSpec.from_dict(self.node.base.attributes.get(ATTR_OUTPUTS_SPEC) or {})
117136
serializers = self.node.base.attributes.get(ATTR_SERIALIZERS, {})
137+
118138
outputs, exit_code = parse_outputs(
119139
results,
120140
output_spec=outputs_spec,
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import functools
5+
import logging
6+
import traceback
7+
from typing import Any, Callable, Optional
8+
9+
import plumpy
10+
import plumpy.futures
11+
import plumpy.persistence
12+
import plumpy.process_states
13+
from aiida.engine.processes.process import Process, ProcessState
14+
from aiida.engine.utils import InterruptableFuture, interruptable_task
15+
16+
from aiida_pythonjob.calculations.common import ATTR_DESERIALIZERS
17+
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
async def task_run_job(process: Process, *args, **kwargs) -> Any:
23+
"""Run the *async* user function and return results or a structured error."""
24+
node = process.node
25+
26+
inputs = dict(process.inputs.function_inputs or {})
27+
deserializers = node.base.attributes.get(ATTR_DESERIALIZERS, {})
28+
inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)
29+
30+
try:
31+
logger.info(f"scheduled request to run the function<{node.pk}>")
32+
results = await process.func(**inputs) # async user function
33+
logger.info(f"running function<{node.pk}> successful")
34+
return {"__ok__": True, "results": results}
35+
except Exception as exception:
36+
logger.warning(f"running function<{node.pk}> failed")
37+
return {
38+
"__error__": "ERROR_FUNCTION_EXECUTION_FAILED",
39+
"exception": str(exception),
40+
"traceback": traceback.format_exc(),
41+
}
42+
43+
44+
@plumpy.persistence.auto_persist("msg", "data")
45+
class Waiting(plumpy.process_states.Waiting):
46+
"""The waiting state for the `PyFunction` process."""
47+
48+
def __init__(
49+
self,
50+
process: Process,
51+
done_callback: Optional[Callable[..., Any]],
52+
msg: Optional[str] = None,
53+
data: Optional[Any] = None,
54+
):
55+
super().__init__(process, done_callback, msg, data)
56+
self._task: InterruptableFuture | None = None
57+
self._killing: plumpy.futures.Future | None = None
58+
59+
@property
60+
def process(self) -> Process:
61+
return self.state_machine
62+
63+
def load_instance_state(self, saved_state, load_context):
64+
super().load_instance_state(saved_state, load_context)
65+
self._task = None
66+
self._killing = None
67+
68+
async def execute(self) -> plumpy.process_states.State:
69+
node = self.process.node
70+
node.set_process_status("Running async function")
71+
try:
72+
payload = await self._launch_task(task_run_job, self.process)
73+
74+
# Convert structured payloads into the next state or an ExitCode
75+
if payload.get("__ok__"):
76+
return self.parse(payload["results"])
77+
elif payload.get("__error__"):
78+
err = payload["__error__"]
79+
if err == "ERROR_DESERIALIZE_INPUTS_FAILED":
80+
exit_code = self.process.exit_codes.ERROR_DESERIALIZE_INPUTS_FAILED.format(
81+
exception=payload.get("exception", ""),
82+
traceback=payload.get("traceback", ""),
83+
)
84+
else:
85+
exit_code = self.process.exit_codes.ERROR_FUNCTION_EXECUTION_FAILED.format(
86+
exception=payload.get("exception", ""),
87+
traceback=payload.get("traceback", ""),
88+
)
89+
# Jump straight to FINISHED by scheduling parse with the error ExitCode
90+
# We reuse the Running->parse path so the process finishes uniformly.
91+
return self.create_state(ProcessState.RUNNING, self.process.parse, {"__exit_code__": exit_code})
92+
except plumpy.process_states.KillInterruption as exception:
93+
node.set_process_status(str(exception))
94+
raise
95+
except (plumpy.futures.CancelledError, asyncio.CancelledError):
96+
node.set_process_status('Function task "run" was cancelled')
97+
raise
98+
except plumpy.process_states.Interruption:
99+
node.set_process_status('Function task "run" was interrupted')
100+
raise
101+
finally:
102+
node.set_process_status(None)
103+
if self._killing and not self._killing.done():
104+
self._killing.set_result(False)
105+
106+
async def _launch_task(self, coro, *args, **kwargs):
107+
"""Launch a coroutine as a task, making sure it is interruptable."""
108+
task_fn = functools.partial(coro, *args, **kwargs)
109+
try:
110+
self._task = interruptable_task(task_fn)
111+
return await self._task
112+
finally:
113+
self._task = None
114+
115+
def parse(self, results: dict) -> plumpy.process_states.Running:
116+
"""Advance to RUNNING where the process' `parse` will be called with results."""
117+
return self.create_state(ProcessState.RUNNING, self.process.parse, results)
118+
119+
def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override]
120+
if self._task is not None:
121+
self._task.interrupt(reason)
122+
if isinstance(reason, plumpy.process_states.KillInterruption):
123+
if self._killing is None:
124+
self._killing = plumpy.futures.Future()
125+
return self._killing
126+
return None
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from .common_data import DateTimeData
12
from .pickled_data import PickledData
23
from .serializer import general_serializer, serialize_to_aiida_nodes
34

4-
__all__ = ("PickledData", "general_serializer", "serialize_to_aiida_nodes")
5+
__all__ = ("DateTimeData", "PickledData", "general_serializer", "serialize_to_aiida_nodes")

src/aiida_pythonjob/data/common_data.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import datetime
2+
13
from aiida import orm
4+
from aiida.orm import Data
25

36

47
class NoneData(orm.Data):
@@ -24,3 +27,22 @@ def __repr__(self) -> str:
2427

2528
def __str__(self) -> str:
2629
return "NoneData()"
30+
31+
32+
class DateTimeData(Data):
33+
"""AiiDA node to store a datetime.datetime object."""
34+
35+
def __init__(self, value: datetime.datetime, **kwargs):
36+
if not isinstance(value, datetime.datetime):
37+
raise TypeError(f"Expected datetime.datetime, got {type(value)}")
38+
super().__init__(**kwargs)
39+
# Store as ISO string for portability
40+
self.base.attributes.set("datetime", value.isoformat())
41+
42+
@property
43+
def value(self) -> datetime.datetime:
44+
"""Return the stored datetime as a datetime object."""
45+
return datetime.datetime.fromisoformat(self.base.attributes.get("datetime"))
46+
47+
def __str__(self):
48+
return str(self.value)

0 commit comments

Comments
 (0)