Skip to content

Commit f5199fe

Browse files
tomMoral, ogrisel, jeremiedbb
authored
MTN deterministic co_filename for dynamic code pickling (#560)
Co-authored-by: Olivier Grisel <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 3a878eb commit f5199fe

File tree

5 files changed

+117
-5
lines changed

5 files changed

+117
-5
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969
shell: bash
7070
run: |
7171
COVERAGE_PROCESS_START=$GITHUB_WORKSPACE/.coveragerc \
72-
PYTHONPATH='.:tests' python -m pytest -r s
72+
PYTHONPATH='.:tests' python -m pytest -r s -vs
7373
coverage combine --append
7474
coverage xml -i
7575
- name: Publish coverage results

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
In development
22
==============
33

4+
- Make pickling of functions depending on globals in notebooks more
5+
deterministic. ([PR#560](https://github.com/cloudpipe/cloudpickle/pull/560))
6+
47
3.1.2
58
=====
69

cloudpickle/cloudpickle.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,13 @@ def _code_reduce(obj):
837837
# See the inline comment in _class_setstate for details.
838838
co_name = "".join(obj.co_name)
839839

840+
# co_filename is not used in the constructor of code objects, so we can
841+
# safely set it to indicate that this is dynamic code. This also makes
842+
# the payload deterministic, independent of where the function is defined
843+
# which is especially useful when defining classes in jupyter/ipython
844+
# cells which do not have a deterministic filename.
845+
co_filename = "".join("<dynamic-code>")
846+
840847
# Create shallow copies of these tuples to make cloudpickle payload deterministic.
841848
# When creating a code object during load, copies of these four tuples are
842849
# created, while in the main process, these tuples can be shared.
@@ -859,7 +866,7 @@ def _code_reduce(obj):
859866
obj.co_consts,
860867
co_names,
861868
co_varnames,
862-
obj.co_filename,
869+
co_filename,
863870
co_name,
864871
obj.co_qualname,
865872
obj.co_firstlineno,
@@ -882,7 +889,7 @@ def _code_reduce(obj):
882889
obj.co_consts,
883890
co_names,
884891
co_varnames,
885-
obj.co_filename,
892+
co_filename,
886893
co_name,
887894
obj.co_firstlineno,
888895
obj.co_linetable,
@@ -903,7 +910,7 @@ def _code_reduce(obj):
903910
obj.co_code,
904911
obj.co_consts,
905912
co_varnames,
906-
obj.co_filename,
913+
co_filename,
907914
co_name,
908915
obj.co_firstlineno,
909916
obj.co_lnotab,
@@ -927,7 +934,7 @@ def _code_reduce(obj):
927934
obj.co_consts,
928935
co_names,
929936
co_varnames,
930-
obj.co_filename,
937+
co_filename,
931938
co_name,
932939
obj.co_firstlineno,
933940
obj.co_lnotab,

dev-requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ pytest-cov
77
psutil
88
# To be able to test tornado coroutines
99
tornado
10+
# To be able to test behavior in Jupyter notebooks
11+
ipykernel
1012
# To be able to test numpy specific things
1113
# but do not build numpy from source on Python nightly
1214
numpy >=1.18.5; python_version <= '3.12'
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import sys
2+
import time
3+
import pytest
4+
import platform
5+
import textwrap
6+
from queue import Empty
7+
8+
from .testutils import check_deterministic_pickle
9+
10+
# Skip the whole module on Windows when running a Python older than 3.11
# (reason given in the skip message below).
if sys.platform == "win32" and sys.version_info < (3, 11):
    pytest.skip(
        "ipykernel requires Python 3.11 or later",
        allow_module_level=True,
    )

# Skip the whole module as well when ipykernel cannot be imported.
ipykernel = pytest.importorskip("ipykernel")
17+
18+
19+
def run_in_notebook(code, timeout=10):
    """Run ``code`` in a freshly started kernel and report how the run ended.

    Parameters
    ----------
    code : str
        Source submitted to the kernel through ``kc.execute``.
    timeout : float, default=10
        Seconds to wait for *each* message on the iopub channel. This is a
        per-message wait, not a bound on the total execution time.

    Returns
    -------
    status : str
        ``"ok"`` when the kernel went back to idle, ``"error"`` when a
        message carrying a traceback was received, ``"timeout"`` when no
        message arrived within ``timeout`` seconds.
    output : str or None
        Everything the code wrote to stdout, or ``None`` if it wrote nothing.
    err : str or None
        Newline-joined traceback of the last error message, or ``None``.
    """
    km = ipykernel.connect.jupyter_client.KernelManager()
    km.start_kernel()
    kc = km.client()
    kc.start_channels()
    status, err = "kernel_started", None
    # stdout may arrive split over several stream messages: keep every chunk
    # instead of only the last one received.
    stdout_chunks = []
    running = True
    try:
        assert km.is_alive() and kc.is_alive()
        kc.wait_for_ready()
        idx = kc.execute(code)
        while running:
            try:
                res = kc.iopub_channel.get_msg(timeout=timeout)
            except Empty:
                status = "timeout"
                break
            # Ignore messages that do not answer our execute request.
            if res["parent_header"].get("msg_id") != idx:
                continue
            content = res["content"]
            if content.get("name") == "stdout":
                stdout_chunks.append(content["text"])
            if "traceback" in content:
                err = "\n".join(content["traceback"])
                status = "error"
            # The request is over once the kernel reports the idle state.
            running = content.get("execution_state") != "idle"
    finally:
        try:
            kc.shutdown()
            kc.stop_channels()
        finally:
            # Always tear the kernel down, even if the client-side shutdown
            # above raised, so that no kernel process is left behind.
            km.shutdown_kernel(now=True, restart=False)
        assert not km.is_alive()
    if status not in ["error", "timeout"]:
        # "exec_error" is a defensive fallback: the loop above only exits
        # with running still True through the timeout branch.
        status = "ok" if not running else "exec_error"
    output = "".join(stdout_chunks) if stdout_chunks else None
    return status, output, err
55+
56+
@pytest.mark.skipif(
    platform.python_implementation() == "PyPy",
    reason="Skip PyPy because tests are too slow",
)
@pytest.mark.parametrize(
    "code, expected",
    [
        ("1 + 1", "ok"),
        ("raise ValueError('This is a test error')", "error"),
        ("import time; time.sleep(100)", "timeout"),
    ],
)
def test_run_in_notebook(code, expected):
    """Check the status reported by run_in_notebook for each kind of outcome.

    Covers a successful run, a run raising an exception and a run that
    outlives the per-message timeout.
    """
    code = textwrap.dedent(code)

    # Use a monotonic clock: the measured duration must not be affected by
    # system clock adjustments happening during the test.
    t_start = time.monotonic()
    status, output, err = run_in_notebook(code, timeout=1)
    duration = time.monotonic() - t_start
    assert status == expected, (
        f"Unexpected status: {status}, output: {output}, err: {err}, duration: {duration}"
    )
    # The sleeping snippet must be interrupted long before it completes.
    assert duration < 10, "Timeout not enforced properly"
    if expected == "error":
        assert "This is a test error" in err
79+
80+
def test_deterministic_payload_for_dynamic_func_in_notebook():
    """Check that two separate kernels produce the same pickle payload.

    The same snippet, defining a function that reads a global, is run in two
    independent kernels and the printed ``cloudpickle.dumps`` results are
    compared with ``check_deterministic_pickle``.
    """
    import ast

    code = textwrap.dedent("""
    import cloudpickle

    MY_PI = 3.1415

    def get_pi():
        return MY_PI

    print(cloudpickle.dumps(get_pi))
    """)

    payloads = []
    for _ in range(2):
        status, output, err = run_in_notebook(code)
        assert status == "ok", f"Unexpected status: {status}, err: {err}"
        # The kernel prints the repr of a bytes object: parse it as a literal
        # rather than evaluating arbitrary text with eval().
        payloads.append(ast.literal_eval(output.strip()))

    check_deterministic_pickle(*payloads)

0 commit comments

Comments
 (0)