Skip to content

Commit 083f592

Browse files
authored
Add custom de-serializers and serializers (#15)
This PR allows users to set custom deserializers and serializers, either as inputs or in the `pythonjob.json` configuration file.
1 parent 16c854c commit 083f592

File tree

11 files changed

+255
-84
lines changed

11 files changed

+255
-84
lines changed

docs/gallery/autogen/how_to.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def add(x, y):
349349

350350

351351
######################################################################
352-
# Define your data serializer
352+
# Define your data serializer and deserializer
353353
# --------------------------------------------
354354
#
355355
# PythonJob searches for the data serializer in the `aiida.data` entry points by the
@@ -376,13 +376,54 @@ def add(x, y):
376376
#
377377
# {
378378
# "serializers": {
379-
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
379+
# "ase.atoms.Atoms": "abc.ase.atoms.AtomsData" # use the full path to the serializer
380380
# }
381381
# }
382382
#
383383
# Save the configuration file as `pythonjob.json` in the aiida configuration
384384
# directory (by default, `~/.aiida` directory).
385+
#
386+
# If you want to pass an AiiDA Data node as input, and the node does not have a `value` attribute,
387+
# then you must provide a deserializer for it.
388+
#
389+
390+
from aiida import orm # noqa: E402
391+
392+
393+
def make_supercell(structure, n=2):
394+
return structure * [n, n, n]
395+
396+
397+
structure = orm.StructureData(cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
398+
structure.append_atom(position=(0.0, 0.0, 0.0), symbols="Li")
399+
400+
inputs = prepare_pythonjob_inputs(
401+
make_supercell,
402+
function_inputs={"structure": structure},
403+
deserializers={
404+
"aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms"
405+
},
406+
)
407+
result, node = run_get_node(PythonJob, inputs=inputs)
408+
print("result: ", result["result"])
385409

410+
######################################################################
411+
# One can also set the deserializer in the configuration file.
412+
#
413+
#
414+
# .. code-block:: json
415+
#
416+
# {
417+
# "serializers": {
418+
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
419+
# },
420+
# "deserializers": {
421+
# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_pymatgen" # noqa
422+
# }
423+
# }
424+
#
425+
# The `orm.List`, `orm.Dict` and `orm.StructureData` data types already have built-in deserializers.
426+
#
386427

387428
######################################################################
388429
# What's Next

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
5252
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
5353
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
5454
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
55-
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
56-
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"
55+
"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List"
56+
"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict"
5757

5858

5959
[project.entry-points."aiida.calculations"]

src/aiida_pythonjob/calculations/pythonjob.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import typing as t
77

88
from aiida.common.datastructures import CalcInfo, CodeInfo
9-
from aiida.common.extendeddicts import AttributeDict
109
from aiida.common.folders import Folder
1110
from aiida.engine import CalcJob, CalcJobProcessSpec
1211
from aiida.orm import (
1312
Data,
13+
Dict,
1414
FolderData,
1515
List,
1616
RemoteData,
@@ -92,6 +92,22 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
9292
serializer=to_aiida_type,
9393
help="Additional filenames to retrieve from the remote work directory",
9494
)
95+
spec.input(
96+
"deserializers",
97+
valid_type=Dict,
98+
default=None,
99+
required=False,
100+
serializer=to_aiida_type,
101+
help="The deserializers to convert the input AiiDA data nodes to raw Python data.",
102+
)
103+
spec.input(
104+
"serializers",
105+
valid_type=Dict,
106+
default=None,
107+
required=False,
108+
serializer=to_aiida_type,
109+
help="The serializers to convert the raw Python data to AiiDA data nodes.",
110+
)
95111
spec.outputs.dynamic = True
96112
# set default options (optional)
97113
spec.inputs["metadata"]["options"]["parser_name"].default = "pythonjob.pythonjob"
@@ -190,6 +206,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
190206
import cloudpickle as pickle
191207

192208
from aiida_pythonjob.calculations.utils import generate_script_py
209+
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data
193210

194211
dirpath = pathlib.Path(folder._abspath)
195212

@@ -279,17 +296,13 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
279296

280297
# Create a pickle file for the user input values
281298
input_values = {}
282-
for key, value in inputs.items():
283-
if isinstance(value, Data) and hasattr(value, "value"):
284-
input_values[key] = value.value
285-
elif isinstance(value, (AttributeDict, dict)):
286-
# Convert an AttributeDict/dict with .value items
287-
input_values[key] = {k: v.value for k, v in value.items()}
288-
else:
289-
raise ValueError(
290-
f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or "
291-
"AttributeDict/dict-of-Data are allowed."
292-
)
299+
if "deserializers" in self.inputs and self.inputs.deserializers:
300+
deserializers = self.inputs.deserializers.get_dict()
301+
# replace "__dot__" with "." in the keys
302+
deserializers = {k.replace("__dot__", "."): v for k, v in deserializers.items()}
303+
else:
304+
deserializers = None
305+
input_values = deserialize_to_raw_python_data(inputs, deserializers=deserializers)
293306

294307
filename = "inputs.pickle"
295308
with folder.open(filename, "wb") as handle:

src/aiida_pythonjob/data/data_with_value.py

Lines changed: 0 additions & 13 deletions
This file was deleted.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from aiida import common, orm
6+
7+
from aiida_pythonjob.config import load_config
8+
from aiida_pythonjob.utils import import_from_path
9+
10+
builtin_deserializers = {
11+
"aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list",
12+
"aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict",
13+
"aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms",
14+
}
15+
16+
17+
def generate_aiida_node_deserializer(data: orm.Node) -> dict:
18+
if isinstance(data, orm.Data):
19+
return data.backend_entity.attributes
20+
elif isinstance(data, (common.extendeddicts.AttributeDict, dict)):
21+
# if the data is an AttributeDict, use it directly
22+
return {k: generate_aiida_node_deserializer(v) for k, v in data.items()}
23+
24+
25+
def list_data_to_list(data):
26+
return data.get_list()
27+
28+
29+
def dict_data_to_dict(data):
30+
return data.get_dict()
31+
32+
33+
def structure_data_to_atoms(structure):
34+
return structure.get_ase()
35+
36+
37+
def structure_data_to_pymatgen(structure):
38+
return structure.get_pymatgen()
39+
40+
41+
def get_deserializer() -> dict:
42+
"""Retrieve the deserializers: the built-in ones, updated with any set in the configuration file."""
43+
configs = load_config()
44+
custom_deserializers = configs.get("deserializers", {})
45+
deserializers = builtin_deserializers.copy()
46+
deserializers.update(custom_deserializers)
47+
return deserializers
48+
49+
50+
all_deserializers = get_deserializer()
51+
52+
53+
def deserialize_to_raw_python_data(data: orm.Node, deserializers: dict | None = None) -> Any:
54+
"""Deserialize the AiiDA data node to raw Python data."""
55+
56+
updated_deserializers = all_deserializers.copy()
57+
58+
if deserializers is not None:
59+
updated_deserializers.update(deserializers)
60+
61+
if isinstance(data, orm.Data):
62+
if hasattr(data, "value"):
63+
return getattr(data, "value")
64+
data_type = type(data)
65+
ep_key = f"{data_type.__module__}.{data_type.__name__}"
66+
if ep_key in updated_deserializers:
67+
deserializer = import_from_path(updated_deserializers[ep_key])
68+
return deserializer(data)
69+
else:
70+
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
71+
elif isinstance(data, (common.extendeddicts.AttributeDict, dict)):
72+
# if the data is an AttributeDict, use it directly
73+
return {k: deserialize_to_raw_python_data(v, deserializers=deserializers) for k, v in data.items()}

src/aiida_pythonjob/data/serializer.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1+
from __future__ import annotations
2+
13
import sys
24
from importlib.metadata import entry_points
35
from typing import Any
46

57
from aiida import common, orm
68

79
from aiida_pythonjob.config import load_config
10+
from aiida_pythonjob.utils import import_from_path
811

12+
from .deserializer import all_deserializers
913
from .pickled_data import PickledData
1014

1115

12-
def get_serializer_from_entry_points() -> dict:
13-
"""Retrieve the serializer from the entry points."""
14-
# import time
16+
def atoms_to_structure_data(structure):
17+
return orm.StructureData(ase=structure)
1518

16-
# ts = time.time()
17-
configs = load_config()
18-
serializers = configs.get("serializers", {})
19-
excludes = serializers.get("excludes", [])
19+
20+
def get_serializers_from_entry_points() -> dict:
2021
# Retrieve the entry points for 'aiida.data' and store them in a dictionary
2122
eps = entry_points()
2223
if sys.version_info >= (3, 10):
@@ -28,28 +29,39 @@ def get_serializer_from_entry_points() -> dict:
2829
# split the entry point name by first ".", and check the last part
2930
key = ep.name.split(".", 1)[-1]
3031
# skip key without "." because it is not a module name for a data type
31-
if "." not in key or key in excludes:
32+
if "." not in key:
3233
continue
3334
eps.setdefault(key, [])
34-
eps[key].append(ep)
35+
# get the path of the entry point value and replace ":" with "."
36+
eps[key].append(ep.value.replace(":", "."))
37+
return eps
38+
3539

40+
def get_serializers() -> dict:
41+
"""Retrieve the serializer from the entry points."""
42+
# import time
43+
44+
# ts = time.time()
45+
all_serializers = {}
46+
configs = load_config()
47+
custom_serializers = configs.get("serializers", {})
48+
eps = get_serializers_from_entry_points()
3649
# check if there are duplicates
3750
for key, value in eps.items():
3851
if len(value) > 1:
39-
if key in serializers:
40-
eps[key] = [ep for ep in value if ep.name == serializers[key]]
41-
if not eps[key]:
42-
raise ValueError(f"Entry point {serializers[key]} not found for {key}")
43-
else:
44-
msg = f"Duplicate entry points for {key}: {[ep.name for ep in value]}"
52+
if key not in custom_serializers:
53+
msg = f"Duplicate entry points for {key}: {value}. You can specify the one to use in the configuration file." # noqa
4554
raise ValueError(msg)
46-
return eps
55+
all_serializers[key] = value[0]
56+
all_serializers.update(custom_serializers)
57+
# print("Time to get serializer", time.time() - ts)
58+
return all_serializers
4759

4860

49-
eps = get_serializer_from_entry_points()
61+
all_serializers = get_serializers()
5062

5163

52-
def serialize_to_aiida_nodes(inputs: dict) -> dict:
64+
def serialize_to_aiida_nodes(inputs: dict, serializers: dict | None = None, deserializers: dict | None = None) -> dict:
5365
"""Serialize the inputs to a dictionary of AiiDA data nodes.
5466
5567
Args:
@@ -61,7 +73,7 @@ def serialize_to_aiida_nodes(inputs: dict) -> dict:
6173
new_inputs = {}
6274
# save all kwargs to inputs port
6375
for key, data in inputs.items():
64-
new_inputs[key] = general_serializer(data)
76+
new_inputs[key] = general_serializer(data, serializers=serializers, deserializers=deserializers)
6577
return new_inputs
6678

6779

@@ -72,11 +84,24 @@ def clean_dict_key(data):
7284
return data
7385

7486

75-
def general_serializer(data: Any, check_value=True) -> orm.Node:
87+
def general_serializer(
88+
data: Any, serializers: dict | None = None, deserializers: dict | None = None, check_value=True
89+
) -> orm.Node:
7690
"""Serialize the data to an AiiDA data node."""
91+
updated_deserializers = all_deserializers.copy()
92+
if deserializers is not None:
93+
updated_deserializers.update(deserializers)
94+
95+
updated_serializers = all_serializers.copy()
96+
if serializers is not None:
97+
updated_serializers.update(serializers)
98+
7799
if isinstance(data, orm.Data):
78100
if check_value and not hasattr(data, "value"):
79-
raise ValueError("Only AiiDA data Node with a value attribute is allowed.")
101+
data_type = type(data)
102+
ep_key = f"{data_type.__module__}.{data_type.__name__}"
103+
if ep_key not in updated_deserializers:
104+
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
80105
return data
81106
elif isinstance(data, common.extendeddicts.AttributeDict):
82107
# if the data is an AttributeDict, use it directly
@@ -92,9 +117,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
92117
data_type = type(data)
93118
ep_key = f"{data_type.__module__}.{data_type.__name__}"
94119
# search for the key in the entry points
95-
if ep_key in eps:
120+
if ep_key in updated_serializers:
96121
try:
97-
new_node = eps[ep_key][0].load()(data)
122+
serializer = import_from_path(updated_serializers[ep_key])
123+
new_node = serializer(data)
98124
except Exception as e:
99125
raise ValueError(f"Error in serializing {ep_key}: {e}")
100126
finally:

src/aiida_pythonjob/launch.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def prepare_pythonjob_inputs(
2020
upload_files: Dict[str, str] = {},
2121
process_label: Optional[str] = None,
2222
function_data: dict | None = None,
23+
deserializers: dict | None = None,
24+
serializers: dict | None = None,
2325
**kwargs: Any,
2426
) -> Dict[str, Any]:
2527
"""Prepare the inputs for PythonJob"""
@@ -55,14 +57,21 @@ def prepare_pythonjob_inputs(
5557
code = get_or_create_code(computer=computer, **command_info)
5658
# serialize the kwargs into AiiDA Data
5759
function_inputs = function_inputs or {}
58-
function_inputs = serialize_to_aiida_nodes(function_inputs)
60+
function_inputs = serialize_to_aiida_nodes(function_inputs, serializers=serializers, deserializers=deserializers)
5961
function_data["outputs"] = function_outputs or [{"name": "result"}]
62+
# replace "." with "__dot__" in the keys of a dictionary
63+
if deserializers:
64+
deserializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in deserializers.items()})
65+
if serializers:
66+
serializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in serializers.items()})
6067
inputs = {
6168
"function_data": function_data,
6269
"code": code,
6370
"function_inputs": function_inputs,
6471
"upload_files": new_upload_files,
6572
"metadata": metadata or {},
73+
"deserializers": deserializers,
74+
"serializers": serializers,
6675
**kwargs,
6776
}
6877
if process_label:

0 commit comments

Comments
 (0)