Skip to content

Commit 94ccf34

Browse files
authored
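Do not store pickled function data (#24)
Stop storing the pickled function as a `PickledData` node: `function_data.pickled_function` is now the raw bytes returned by `cloudpickle.dumps`, and only `source_code` is saved, as a node attribute (set in `PythonJob._setup_db_record`).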
1 parent 7e3cf90 commit 94ccf34

File tree

5 files changed: 44 additions and 41 deletions.

src/aiida_pythonjob/calculations/pyfunction.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,18 @@ def __init__(self, *args, **kwargs) -> None:
3838

3939
@property
4040
def func(self) -> t.Callable[..., t.Any]:
41+
import cloudpickle
42+
4143
if self._func is None:
42-
self._func = self.inputs.function_data.pickled_function.value
44+
self._func = cloudpickle.loads(self.inputs.function_data.pickled_function)
4345
return self._func
4446

4547
@classmethod
4648
def define(cls, spec: ProcessSpec) -> None: # type: ignore[override]
4749
"""Define the process specification, including its inputs, outputs and known exit codes."""
4850
super().define(spec)
49-
spec.input_namespace("function_data")
50-
spec.input("function_data.name", valid_type=Str, serializer=to_aiida_type)
51-
spec.input("function_data.source_code", valid_type=Str, serializer=to_aiida_type, required=False)
51+
spec.input_namespace("function_data", dynamic=True, required=True)
5252
spec.input("function_data.outputs", valid_type=List, serializer=to_aiida_type, required=False)
53-
spec.input("function_data.pickled_function", valid_type=Data, required=False)
54-
spec.input("function_data.mode", valid_type=Str, serializer=to_aiida_type, required=False)
5553
spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False)
5654
spec.input_namespace("function_inputs", valid_type=Data, required=False)
5755
spec.input(
@@ -88,10 +86,10 @@ def define(cls, spec: ProcessSpec) -> None: # type: ignore[override]
8886
def get_function_name(self) -> str:
8987
"""Return the name of the function to run."""
9088
if "name" in self.inputs.function_data:
91-
name = self.inputs.function_data.name.value
89+
name = self.inputs.function_data.name
9290
else:
9391
try:
94-
name = self.inputs.function_data.pickled_function.value.__name__
92+
name = self.func.__name__
9593
except AttributeError:
9694
# If a user doesn't specify a name, fall back to something generic
9795
name = "anonymous_function"

src/aiida_pythonjob/calculations/pythonjob.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from aiida.common.datastructures import CalcInfo, CodeInfo
99
from aiida.common.folders import Folder
10+
from aiida.common.lang import override
1011
from aiida.engine import CalcJob, CalcJobProcessSpec
1112
from aiida.orm import (
1213
Data,
@@ -37,17 +38,14 @@ class PythonJob(CalcJob):
3738
_DEFAULT_INPUT_FILE = "script.py"
3839
_DEFAULT_OUTPUT_FILE = "aiida.out"
3940
_DEFAULT_PARENT_FOLDER_NAME = "./parent_folder/"
41+
_SOURCE_CODE_KEY = "source_code"
4042

4143
@classmethod
4244
def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
4345
"""Define the process specification, including its inputs, outputs and known exit codes."""
4446
super().define(spec)
45-
spec.input_namespace("function_data")
46-
spec.input("function_data.name", valid_type=Str, serializer=to_aiida_type)
47-
spec.input("function_data.source_code", valid_type=Str, serializer=to_aiida_type, required=False)
47+
spec.input_namespace("function_data", dynamic=True, required=True)
4848
spec.input("function_data.outputs", valid_type=List, serializer=to_aiida_type, required=False)
49-
spec.input("function_data.pickled_function", valid_type=Data, required=False)
50-
spec.input("function_data.mode", valid_type=Str, serializer=to_aiida_type, required=False)
5149
spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False)
5250
spec.input_namespace("function_inputs", valid_type=Data, required=False)
5351
spec.input(
@@ -175,13 +173,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
175173
def get_function_name(self) -> str:
176174
"""Return the name of the function to run."""
177175
if "name" in self.inputs.function_data:
178-
name = self.inputs.function_data.name.value
176+
name = self.inputs.function_data.name
179177
else:
180-
try:
181-
name = self.inputs.function_data.pickled_function.value.__name__
182-
except AttributeError:
183-
# If a user doesn't specify name, fallback to something generic
184-
name = "anonymous_function"
178+
name = "anonymous_function"
185179
return name
186180

187181
def _build_process_label(self) -> str:
@@ -192,6 +186,13 @@ def _build_process_label(self) -> str:
192186
name = self.get_function_name()
193187
return f"PythonJob<{name}>"
194188

189+
@override
190+
def _setup_db_record(self) -> None:
191+
"""Set up the database record for the process."""
192+
super()._setup_db_record()
193+
if "source_code" in self.inputs.function_data:
194+
self.node.base.attributes.set(self._SOURCE_CODE_KEY, self.inputs.function_data.source_code)
195+
195196
def on_create(self) -> None:
196197
"""Called when a Process is created."""
197198
super().on_create()
@@ -223,19 +224,13 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
223224
else:
224225
parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME
225226

226-
function_data = self.inputs.function_data
227-
228227
# Build the Python script
229-
source_code = function_data.get("source_code")
230-
if "pickled_function" in self.inputs.function_data:
231-
pickled_function = self.inputs.function_data.pickled_function.get_serialized_value()
232-
else:
233-
pickled_function = None
234-
# Generate script.py content
228+
source_code = self.node.base.attributes.get(self._SOURCE_CODE_KEY, None)
229+
pickled_function = self.inputs.function_data.pickled_function
235230
function_name = self.get_function_name() # or some user-defined name
236231
script_content = generate_script_py(
237232
pickled_function=pickled_function,
238-
source_code=source_code.value if source_code else None,
233+
source_code=source_code,
239234
function_name=function_name,
240235
)
241236

src/aiida_pythonjob/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def inspect_function(
5454
# the source code is not saved in the pickle file
5555
import cloudpickle
5656

57-
from aiida_pythonjob.data.pickled_data import PickledData
58-
5957
if inspect_source:
6058
try:
6159
source_code = inspect.getsource(func)
@@ -70,10 +68,10 @@ def inspect_function(
7068
if register_pickle_by_value:
7169
module = importlib.import_module(func.__module__)
7270
cloudpickle.register_pickle_by_value(module)
73-
pickled_function = PickledData(value=func)
71+
pickled_function = cloudpickle.dumps(func)
7472
cloudpickle.unregister_pickle_by_value(module)
7573
else:
76-
pickled_function = PickledData(value=func)
74+
pickled_function = cloudpickle.dumps(func)
7775

7876
return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": pickled_function}
7977

tests/test_pythonjob.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,22 @@ def add(x, y):
310310
)
311311
result, node = run_get_node(PythonJob, **inputs)
312312
assert result["result"].value == 8
313+
314+
315+
@pytest.mark.usefixtures("started_daemon_client")
316+
def test_submit(fixture_localhost):
317+
"""Test submitting a PythonJob to the daemon and waiting for it to finish."""
318+
from aiida.engine import submit
319+
320+
def add(x, y):
321+
return x + y
322+
323+
inputs = prepare_pythonjob_inputs(
324+
add,
325+
function_inputs={"x": 1, "y": 2},
326+
process_label="add",
327+
)
328+
node = submit(PythonJob, **inputs, wait=True)
329+
330+
assert node.outputs.result.value == 3
331+
assert node.process_label == "add"

tests/test_utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,12 @@ def test_build_function_data():
1212
assert function_data["name"] == "build_function_data"
1313
assert "source_code" in function_data
1414
assert "pickled_function" in function_data
15-
node = function_data["pickled_function"]
16-
with node.base.repository.open(node.FILENAME, mode="rb") as f:
17-
text = f.read()
18-
assert b"cloudpickle" not in text
19-
15+
assert b"cloudpickle" not in function_data["pickled_function"]
2016
function_data = build_function_data(build_function_data, register_pickle_by_value=True)
2117
assert function_data["name"] == "build_function_data"
2218
assert "source_code" in function_data
2319
assert "pickled_function" in function_data
24-
node = function_data["pickled_function"]
25-
with node.base.repository.open(node.FILENAME, mode="rb") as f:
26-
text = f.read()
27-
assert b"cloudpickle" in text
20+
assert b"cloudpickle" in function_data["pickled_function"]
2821

2922
def local_function(x, y):
3023
return x + y

0 commit comments

Comments
 (0)