Skip to content

Commit 7e3cf90

Browse files
authored
pyfunction handles already serialized results (#23)
1 parent 88f9298 commit 7e3cf90

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

src/aiida_pythonjob/calculations/pyfunction.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,10 @@ def run(self) -> ExitCode | None:
202202
if exit_code.status != 0:
203203
return exit_code
204204
if len(top_level_output_list) == 1:
205-
# If output name in results, use it
206-
if top_level_output_list[0]["name"] in results:
205+
# User returned a single (nested) dict with AiiDA data nodes as values
206+
if self.already_serialized(results):
207+
top_level_output_list = [{"name": key, "value": value} for key, value in results.items()]
208+
elif top_level_output_list[0]["name"] in results:
207209
top_level_output_list[0]["value"] = self.serialize_output(
208210
results.pop(top_level_output_list[0]["name"]),
209211
top_level_output_list[0],
@@ -228,8 +230,14 @@ def run(self) -> ExitCode | None:
228230
if len(results) > 0:
229231
self.logger.warning(f"Found extra results that are not included in the output: {results.keys()}")
230232
elif len(top_level_output_list) == 1:
231-
# Single top-level output, single result
232-
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
233+
# Single top-level output
234+
# There are two cases:
235+
# 1. The output as a whole will be serialized as the single output
236+
# 2. The output is a mapping with already AiiDA data nodes as values, no need to serialize
237+
if self.already_serialized(results):
238+
top_level_output_list[0]["value"] = results
239+
else:
240+
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
233241
else:
234242
return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
235243
# Store the outputs
@@ -238,6 +246,20 @@ def run(self) -> ExitCode | None:
238246

239247
return ExitCode()
240248

249+
def already_serialized(self, results):
250+
"""Check if the results are already serialized."""
251+
import collections
252+
253+
if isinstance(results, Data):
254+
return True
255+
elif isinstance(results, collections.abc.Mapping):
256+
for value in results.values():
257+
if not self.already_serialized(value):
258+
return False
259+
return True
260+
else:
261+
return False
262+
241263
def find_output(self, name):
242264
"""Find the output spec with the given name."""
243265
for output in self.output_list:

tests/test_pyfunction.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from aiida import orm
12
from aiida.engine import run_get_node
23
from aiida_pythonjob import pyfunction
34

@@ -118,3 +119,15 @@ def myfunc(x, y):
118119
assert result["add_multiply"]["add"]["order1"].value == 3
119120
assert result["add_multiply"]["add"]["order2"].value == 5
120121
assert result["add_multiply"]["multiply"].value == 2
122+
123+
124+
def test_aiida_node_as_inputs_outputs():
125+
"""Test function with AiiDA nodes as inputs and outputs."""
126+
127+
@pyfunction()
128+
def add(x, y):
129+
return {"sum": orm.Int(x + y), "diff": orm.Int(x - y)}
130+
131+
result, node = run_get_node(add, x=orm.Int(1), y=orm.Int(2))
132+
assert set(result.keys()) == {"sum", "diff"}
133+
assert result["sum"].value == 3

0 commit comments

Comments
 (0)