Skip to content

Commit 30383c6

Browse files
authored
Replace the dict-based way of declaring sockets and namespaces with the spec API (#35)
This commit replaces the old dict-based way of declaring sockets and namespaces with the `spec` API from `node-graph`. In this way, one can declare: top-level dynamic outputs, `output_spec=spec.dynamic(any)`; a dynamic namespace, `output_spec=spec.dynamic(spec.namespace(sum=int))`.
1 parent a70e9a0 commit 30383c6

File tree

14 files changed

+426
-400
lines changed

14 files changed

+426
-400
lines changed

docs/gallery/autogen/pyfunction.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
######################################################################
88
# Default outputs
9-
# --------------
9+
# -----------------
1010
#
1111
# The default output of the function is `result`. The `pyfunction` task
1212
# will store the result as one node in the database with the key `result`.
1313
#
1414
from aiida import load_profile
1515
from aiida.engine import run_get_node
16-
from aiida_pythonjob import pyfunction
16+
from aiida_pythonjob import pyfunction, spec
1717

1818
load_profile()
1919

@@ -35,7 +35,7 @@ def add(x, y):
3535
#
3636

3737

38-
@pyfunction(outputs=[{"name": "sum"}, {"name": "diff"}])
38+
@pyfunction(outputs=spec.namespace(sum=any, diff=any))
3939
def add(x, y):
4040
return {"sum": x + y, "diff": x - y}
4141

@@ -48,7 +48,7 @@ def add(x, y):
4848

4949
######################################################################
5050
# Namespace Output
51-
# --------------
51+
# -----------------
5252
#
5353
# The `pyfunction` allows users to define namespace outputs. A namespace output
5454
# is a dictionary with keys and values returned by a function. Each value in
@@ -70,20 +70,20 @@ def add(x, y):
7070
from ase.build import bulk # noqa: E402
7171

7272

73-
@pyfunction(outputs=[{"name": "scaled_structures", "identifier": "namespace"}])
73+
@pyfunction(outputs=spec.dynamic(Atoms))
7474
def generate_structures(structure: Atoms, factor_lst: list) -> dict:
7575
"""Scale the structure by the given factor_lst."""
7676
scaled_structures = {}
7777
for i in range(len(factor_lst)):
7878
atoms = structure.copy()
7979
atoms.set_cell(atoms.cell * factor_lst[i], scale_atoms=True)
8080
scaled_structures[f"s_{i}"] = atoms
81-
return {"scaled_structures": scaled_structures}
81+
return scaled_structures
8282

8383

8484
result, node = run_get_node(generate_structures, structure=bulk("Al"), factor_lst=[0.95, 1.0, 1.05])
8585
print("scaled_structures: ")
86-
for key, value in result["scaled_structures"].items():
86+
for key, value in result.items():
8787
print(key, value)
8888

8989

@@ -115,7 +115,7 @@ def add(x, y):
115115

116116
######################################################################
117117
# Define your data serializer and deserializer
118-
# --------------
118+
# ----------------------------------------------
119119
#
120120
# PythonJob search data serializer from the `aiida.data` entry point by the
121121
# module name and class name (e.g., `ase.atoms.Atoms`).

docs/gallery/autogen/pythonjob.py

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@
6363

6464
######################################################################
6565
# Default outputs
66-
# --------------
66+
# ----------------
6767
#
6868
# The default output of the function is `result`. The `PythonJob` task
6969
# will store the result as one node in the database with the key `result`.
7070
#
7171
from aiida import load_profile
7272
from aiida.engine import run_get_node
73-
from aiida_pythonjob import PythonJob, prepare_pythonjob_inputs
73+
from aiida_pythonjob import PythonJob, prepare_pythonjob_inputs, spec
7474

7575
load_profile()
7676

@@ -91,7 +91,7 @@ def add(x, y):
9191
# Custom outputs
9292
# --------------
9393
# If the function return a dictionary with fixed number of keys, and you
94-
# want to store the values as separate outputs, you can specify the `output_ports` parameter.
94+
# want to store the values as separate outputs, you can specify the `outputs_spec` parameter.
9595
# For a dynamic number of outputs, you can use the namespace output, which is explained later.
9696
#
9797

@@ -103,10 +103,7 @@ def add(x, y):
103103
inputs = prepare_pythonjob_inputs(
104104
add,
105105
function_inputs={"x": 1, "y": 2},
106-
output_ports=[
107-
{"name": "sum"},
108-
{"name": "diff"},
109-
],
106+
outputs_spec=spec.namespace(sum=any, diff=any),
110107
)
111108
result, node = run_get_node(PythonJob, **inputs)
112109

@@ -117,7 +114,7 @@ def add(x, y):
117114

118115
######################################################################
119116
# Using parent folder
120-
# --------------
117+
# -----------------------
121118
# The parent_folder parameter allows a task to access the output files of
122119
# a parent task. This feature is particularly useful when you want to reuse
123120
# data generated by a previous computation in subsequent computations. In
@@ -142,15 +139,13 @@ def multiply(x, y):
142139
inputs1 = prepare_pythonjob_inputs(
143140
add,
144141
function_inputs={"x": 1, "y": 2},
145-
output_ports=[{"name": "sum"}],
146142
)
147143

148144
result1, node1 = run_get_node(PythonJob, inputs=inputs1)
149145

150146
inputs2 = prepare_pythonjob_inputs(
151147
multiply,
152148
function_inputs={"x": 1, "y": 2},
153-
output_ports=[{"name": "product"}],
154149
parent_folder=result1["remote_folder"],
155150
)
156151

@@ -160,7 +155,7 @@ def multiply(x, y):
160155

161156
######################################################################
162157
# Upload files or folders to the remote computer
163-
# --------------
158+
# -------------------------------------------------
164159
# The `upload_files` parameter allows users to upload files or folders to
165160
# the remote computer. The files will be uploaded to the working directory of the remote computer.
166161
#
@@ -202,7 +197,7 @@ def add():
202197

203198
######################################################################
204199
# Retrieve additional files from the remote computer
205-
# --------------
200+
# ----------------------------------------------------
206201
# Sometimes, one may want to retrieve additional files from the remote
207202
# computer after the job has finished. For example, one may want to retrieve
208203
# the output files generated by the `pw.x` calculation in Quantum ESPRESSO.
@@ -235,7 +230,7 @@ def add(x, y):
235230

236231
######################################################################
237232
# Namespace Output
238-
# --------------
233+
# ------------------
239234
#
240235
# The `PythonJob` allows users to define namespace outputs. A namespace output
241236
# is a dictionary with keys and values returned by a function. Each value in
@@ -264,18 +259,18 @@ def generate_structures(structure: Atoms, factor_lst: list) -> dict:
264259
atoms = structure.copy()
265260
atoms.set_cell(atoms.cell * factor_lst[i], scale_atoms=True)
266261
scaled_structures[f"s_{i}"] = atoms
267-
return {"scaled_structures": scaled_structures}
262+
return scaled_structures
268263

269264

270265
inputs = prepare_pythonjob_inputs(
271266
generate_structures,
272267
function_inputs={"structure": bulk("Al"), "factor_lst": [0.95, 1.0, 1.05]},
273-
output_ports=[{"name": "scaled_structures", "identifier": "namespace"}],
268+
outputs_spec=spec.dynamic(Atoms),
274269
)
275270

276271
result, node = run_get_node(PythonJob, inputs=inputs)
277272
print("scaled_structures: ")
278-
for key, value in result["scaled_structures"].items():
273+
for key, value in result.items():
279274
print(key, value)
280275

281276

@@ -297,31 +292,20 @@ def generate_structures(structure: Atoms, factor_lst: list) -> dict:
297292
scaled_structures[f"s_{i}"] = atoms
298293
volumes[f"v_{i}"] = atoms.get_volume()
299294
return {
300-
"outputs": {
301-
"scaled_structures": scaled_structures,
302-
"volume": volumes,
303-
}
295+
"scaled_structures": scaled_structures,
296+
"volume": volumes,
304297
}
305298

306299

307300
inputs = prepare_pythonjob_inputs(
308301
generate_structures,
309302
function_inputs={"structure": bulk("Al"), "factor_lst": [0.95, 1.0, 1.05]},
310-
output_ports=[
311-
{
312-
"name": "outputs",
313-
"identifier": "namespace",
314-
"ports": [
315-
{"name": "scaled_structures", "identifier": "namespace"},
316-
{"name": "volume", "identifier": "namespace"},
317-
],
318-
}
319-
],
303+
outputs_spec=spec.namespace(scaled_structures=spec.dynamic(Atoms), volume=spec.dynamic(float)),
320304
)
321305

322306
result, node = run_get_node(PythonJob, inputs=inputs)
323-
print("result: ", result["outputs"]["scaled_structures"])
324-
print("volumes: ", result["outputs"]["volume"])
307+
print("result: ", result["scaled_structures"])
308+
print("volumes: ", result["volume"])
325309

326310

327311
######################################################################
@@ -420,7 +404,7 @@ def add(x, y):
420404

421405
######################################################################
422406
# Define your data serializer and deserializer
423-
# --------------
407+
# ----------------------------------------------
424408
#
425409
# PythonJob search data serializer from the `aiida.data` entry point by the
426410
# module name and class name (e.g., `ase.atoms.Atoms`).

examples/test_add.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ dependencies = [
2424
"aiida-core>=2.3,<3",
2525
"ase",
2626
"cloudpickle",
27+
"node-graph==0.2.22",
2728
]
2829

2930
[project.optional-dependencies]
@@ -163,3 +164,6 @@ features = ["docs"]
163164
build = [
164165
"make -C docs"
165166
]
167+
168+
[tool.hatch.metadata]
169+
allow-direct-references = true

src/aiida_pythonjob/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
__version__ = "0.2.5"
44

5+
from node_graph import spec
6+
57
from .calculations import PythonJob
68
from .decorator import pyfunction
79
from .launch import prepare_pythonjob_inputs
@@ -13,4 +15,5 @@
1315
"PickledData",
1416
"prepare_pythonjob_inputs",
1517
"PythonJobParser",
18+
"spec",
1619
)

src/aiida_pythonjob/calculations/pyfunction.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import traceback
66
import typing as t
77

8+
import cloudpickle
9+
import plumpy
810
from aiida.common.lang import override
911
from aiida.engine import Process, ProcessSpec
1012
from aiida.engine.processes.exit_code import ExitCode
@@ -30,10 +32,18 @@ def __init__(self, *args, **kwargs) -> None:
3032
super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore[misc]
3133
self._func = None
3234

35+
@override
36+
def load_instance_state(
37+
self, saved_state: t.MutableMapping[str, t.Any], load_context: plumpy.persistence.LoadSaveContext
38+
) -> None:
39+
"""Load the instance state from the saved state."""
40+
41+
super().load_instance_state(saved_state, load_context)
42+
# Restore the function from the pickled data
43+
self._func = cloudpickle.loads(self.inputs.function_data.pickled_function)
44+
3345
@property
3446
def func(self) -> t.Callable[..., t.Any]:
35-
import cloudpickle
36-
3747
if self._func is None:
3848
self._func = cloudpickle.loads(self.inputs.function_data.pickled_function)
3949
return self._func
@@ -189,7 +199,8 @@ def parse(self, results):
189199
if exit_code:
190200
return exit_code
191201
# Store the outputs
192-
for output in self.output_ports["ports"]:
193-
self.out(output["name"], output["value"])
202+
for name, port in self.output_ports["ports"].items():
203+
if "value" in port:
204+
self.out(name, port["value"])
194205

195206
return ExitCode()

src/aiida_pythonjob/decorator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, "ProcessNode
6060
manager = get_manager()
6161
runner = manager.get_runner()
6262
# # Remove all the known inputs from the kwargs
63-
output_ports = kwargs.pop("output_ports", None) or outputs
64-
input_ports = kwargs.pop("input_ports", None) or inputs
63+
outputs_spec = kwargs.pop("outputs_spec", None) or outputs
64+
inputs_spec = kwargs.pop("inputs_spec", None) or inputs
65+
input_ports = kwargs.pop("input_ports", None)
66+
output_ports = kwargs.pop("output_ports", None)
6567
metadata = kwargs.pop("metadata", None)
6668
function_data = kwargs.pop("function_data", None)
6769
deserializers = kwargs.pop("deserializers", None)
@@ -73,6 +75,8 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, "ProcessNode
7375
process_inputs = prepare_pyfunction_inputs(
7476
function=function,
7577
function_inputs=function_inputs,
78+
inputs_spec=inputs_spec,
79+
outputs_spec=outputs_spec,
7680
input_ports=input_ports,
7781
output_ports=output_ports,
7882
metadata=metadata,

0 commit comments

Comments
 (0)