Skip to content

Commit 6cfca02

Browse files
committed
add register_pickle_by_value
1 parent 839e3ba commit 6cfca02

File tree

5 files changed

+102
-33
lines changed

5 files changed

+102
-33
lines changed

docs/gallery/autogen/how_to.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,33 @@ def add(x, y):
347347
print("exit_status:", node.exit_status)
348348
print("exit_message:", node.exit_message)
349349

350+
######################################################################
351+
# Using `register_pickle_by_value`
352+
# --------------------------------
353+
#
354+
# If the function is defined inside an external module that is **not installed** on
355+
# the remote computer, this can cause import errors during execution.
356+
#
357+
# **Solution:**
358+
# By enabling `register_pickle_by_value=True`, the function is serialized **by value**
359+
# instead of being referenced by its module path. This embeds the function unpickled
360+
# even if the original module is unavailable on the remote computer.
361+
#
362+
# **Example:**
363+
#
364+
# .. code-block:: python
365+
#
366+
# inputs = prepare_pythonjob_inputs(
367+
# my_function,
368+
# function_inputs={"x": 1, "y": 2},
369+
# computer="localhost",
370+
# register_pickle_by_value=True, # Ensures function is embedded
371+
# )
372+
#
373+
# **Important Considerations:**: If the function **contains import statements**,
374+
# the imported modules **must still be installed** on the remote computer.
375+
#
376+
350377

351378
######################################################################
352379
# Define your data serializer and deserializer

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ dependencies = [
2828
]
2929

3030
[project.optional-dependencies]
31+
dev = [
32+
"hatch",
33+
]
3134
pre-commit = [
3235
'pre-commit~=3.5',
3336
]

src/aiida_pythonjob/launch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def prepare_pythonjob_inputs(
2222
function_data: dict | None = None,
2323
deserializers: dict | None = None,
2424
serializers: dict | None = None,
25+
register_pickle_by_value: bool = False,
2526
**kwargs: Any,
2627
) -> Dict[str, Any]:
2728
"""Prepare the inputs for PythonJob"""
@@ -33,7 +34,7 @@ def prepare_pythonjob_inputs(
3334
raise ValueError("Only one of function or function_data should be provided")
3435
# if function is a function, inspect it and get the source code
3536
if function is not None and inspect.isfunction(function):
36-
function_data = build_function_data(function)
37+
function_data = build_function_data(function, register_pickle_by_value=register_pickle_by_value)
3738
new_upload_files = {}
3839
# change the string in the upload files to SingleFileData, or FolderData
3940
for key, source in upload_files.items():

src/aiida_pythonjob/utils.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib
12
import inspect
23
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, _SpecialForm, get_type_hints
34

@@ -6,8 +7,6 @@
67

78

89
def import_from_path(path: str) -> Any:
9-
import importlib
10-
1110
module_name, object_name = path.rsplit(".", 1)
1211
module = importlib.import_module(module_name)
1312
try:
@@ -47,24 +46,39 @@ def add_imports(type_hint):
4746
return imports
4847

4948

50-
def inspect_function(func: Callable) -> Dict[str, Any]:
49+
def inspect_function(
50+
func: Callable, inspect_source: bool = False, register_pickle_by_value: bool = False
51+
) -> Dict[str, Any]:
5152
"""Serialize a function for storage or transmission."""
5253
# we need save the source code explicitly, because in the case of jupyter notebook,
5354
# the source code is not saved in the pickle file
55+
import cloudpickle
56+
5457
from aiida_pythonjob.data.pickled_data import PickledData
5558

56-
try:
57-
source_code = inspect.getsource(func)
58-
# Split the source into lines for processing
59-
source_code_lines = source_code.split("\n")
60-
source_code = "\n".join(source_code_lines)
61-
except OSError:
62-
source_code = "Failed to retrieve source code."
59+
if inspect_source:
60+
try:
61+
source_code = inspect.getsource(func)
62+
# Split the source into lines for processing
63+
source_code_lines = source_code.split("\n")
64+
source_code = "\n".join(source_code_lines)
65+
except OSError:
66+
source_code = "Failed to retrieve source code."
67+
else:
68+
source_code = ""
69+
70+
if register_pickle_by_value:
71+
module = importlib.import_module(func.__module__)
72+
cloudpickle.register_pickle_by_value(module)
73+
pickled_function = PickledData(value=func)
74+
cloudpickle.unregister_pickle_by_value(module)
75+
else:
76+
pickled_function = PickledData(value=func)
6377

64-
return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": PickledData(value=func)}
78+
return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": pickled_function}
6579

6680

67-
def build_function_data(func: Callable) -> Dict[str, Any]:
81+
def build_function_data(func: Callable, register_pickle_by_value: bool = False) -> Dict[str, Any]:
6882
"""Inspect the function and return a dictionary with the function data."""
6983
import types
7084

@@ -73,15 +87,10 @@ def build_function_data(func: Callable) -> Dict[str, Any]:
7387
function_data = {"name": func.__name__}
7488
if func.__module__ == "__main__" or "." in func.__qualname__.split(".", 1)[-1]:
7589
# Local or nested callable, so pickle the callable
76-
function_data.update(inspect_function(func))
90+
function_data.update(inspect_function(func, inspect_source=True))
7791
else:
7892
# Global callable (function/class), store its module and name for reference
79-
function_data.update(
80-
{
81-
"mode": "use_module_path",
82-
"source_code": f"from {func.__module__} import {func.__name__}",
83-
}
84-
)
93+
function_data.update(inspect_function(func, register_pickle_by_value=register_pickle_by_value))
8594
else:
8695
raise TypeError("Provided object is not a callable function or class.")
8796
return function_data

tests/test_utils.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,46 @@
1+
import pytest
12
from aiida_pythonjob.utils import build_function_data
23

34

45
def test_build_function_data():
5-
from math import sqrt
6-
7-
function_data = build_function_data(sqrt)
8-
assert function_data == {
9-
"name": "sqrt",
10-
"mode": "use_module_path",
11-
"source_code": "from math import sqrt",
12-
}
13-
#
14-
try:
15-
function_data = build_function_data(1)
16-
except Exception as e:
17-
assert str(e) == "Provided object is not a callable function or class."
6+
"""Test the build_function_data function behavior."""
7+
8+
with pytest.raises(TypeError, match="Provided object is not a callable function or class."):
9+
build_function_data(1)
10+
11+
function_data = build_function_data(build_function_data)
12+
assert function_data["name"] == "build_function_data"
13+
assert "source_code" in function_data
14+
assert "pickled_function" in function_data
15+
node = function_data["pickled_function"]
16+
with node.base.repository.open(node.FILENAME, mode="rb") as f:
17+
text = f.read()
18+
assert b"cloudpickle" not in text
19+
20+
function_data = build_function_data(build_function_data, register_pickle_by_value=True)
21+
assert function_data["name"] == "build_function_data"
22+
assert "source_code" in function_data
23+
assert "pickled_function" in function_data
24+
node = function_data["pickled_function"]
25+
with node.base.repository.open(node.FILENAME, mode="rb") as f:
26+
text = f.read()
27+
assert b"cloudpickle" in text
28+
29+
def local_function(x, y):
30+
return x + y
31+
32+
function_data = build_function_data(local_function)
33+
assert function_data["name"] == "local_function"
34+
assert "source_code" in function_data
35+
assert function_data["mode"] == "use_pickled_function"
36+
37+
def outer_function():
38+
def nested_function(x, y):
39+
return x + y
40+
41+
return nested_function
42+
43+
nested_func = outer_function()
44+
function_data = build_function_data(nested_func)
45+
assert function_data["name"] == "nested_function"
46+
assert function_data["mode"] == "use_pickled_function"

0 commit comments

Comments
 (0)