Skip to content

Commit 5f6e9b3

Browse files
committed
Use standard pickle, and make pickler an option
We used cloudpickle to make it easy to author tests where the test worker is defined in the test function. The problem is that cloudpickle will also try to serialize, by value, classes defined in the __main__ package, which leads to serialization problems when e.g. the worker class references unpickleable things in its members. Also, revert #452, as it is addressed by this change. Pull Request: #455
1 parent 75eb0fd commit 5f6e9b3

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

compiler_opt/distributed/buffered_scheduler_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Test for buffered_scheduler."""
1515

16+
import cloudpickle
1617
import concurrent.futures
1718
import threading
1819
import time
@@ -35,7 +36,7 @@ def wait_seconds(self, n: int):
3536
return n + 1
3637

3738
with local_worker_manager.LocalWorkerPoolManager(
38-
WaitingWorker, count=2) as pool:
39+
WaitingWorker, count=2, pickle_func=cloudpickle.dumps) as pool:
3940
_, futures = buffered_scheduler.schedule_on_worker_pool(
4041
lambda w, v: w.wait_seconds(v), range(4), pool)
4142
not_done = futures
@@ -54,7 +55,7 @@ def square(self, the_value, extra_factor=1):
5455
return the_value * the_value * extra_factor
5556

5657
with local_worker_manager.LocalWorkerPoolManager(
57-
TheWorker, count=2) as pool:
58+
TheWorker, count=2, pickle_func=cloudpickle.dumps) as pool:
5859
workers, futures = buffered_scheduler.schedule_on_worker_pool(
5960
lambda w, v: w.square(v), range(10), pool)
6061
self.assertLen(workers, 2)

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
"""
2929
import concurrent.futures
3030
import sys
31-
import cloudpickle
3231
import dataclasses
3332
import functools
3433
import multiprocessing
34+
import pickle
3535
import psutil
3636
import threading
3737

@@ -80,7 +80,7 @@ def _run_impl(pipe: connection.Connection, worker_class: SerializedClass,
8080
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
8181
# unpickle the type, which triggers all the necessary imports. This registers
8282
# abseil flags, which can then be parsed.
83-
cls = cloudpickle.loads(worker_class)
83+
cls = pickle.loads(worker_class)
8484
if parse_argv:
8585
flags.FLAGS(sys.argv, known_only=True)
8686
obj = cls(*args, **kwargs)
@@ -126,7 +126,8 @@ def _run(pipe: connection.Connection, worker_class: SerializedClass,
126126
raise e
127127

128128

129-
def _make_stub(cls: 'type[worker.Worker]', parse_argv: bool, *args, **kwargs):
129+
def _make_stub(cls: 'type[worker.Worker]', parse_argv: bool, pickle_func, *args,
130+
**kwargs):
130131

131132
class _Stub:
132133
"""Client stub to a worker hosted by a process."""
@@ -141,7 +142,7 @@ def __init__(self):
141142
# to handle high priority requests. The expectation is that the user
142143
# achieves concurrency through multiprocessing, not multithreading.
143144
self._process = _get_context().Process(
144-
target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls),
145+
target=functools.partial(_run, child_pipe, pickle_func(cls),
145146
parse_argv, *args, **kwargs))
146147
# lock for the msgid -> reply future map. The map will be set to None
147148
# when we stop.
@@ -254,14 +255,14 @@ def __dir__(self):
254255

255256

256257
def _create_local_worker_pool(worker_cls: 'type[worker.Worker]',
257-
count: int | None, parse_argv: bool, *args,
258-
**kwargs) -> worker.FixedWorkerPool:
258+
count: int | None, parse_argv: bool, pickle_func,
259+
*args, **kwargs) -> worker.FixedWorkerPool:
259260
"""Create a local worker pool for worker_cls."""
260261
if not count:
261262
count = _get_context().cpu_count()
262263
final_kwargs = worker.get_full_worker_args(worker_cls, **kwargs)
263264
stubs = [
264-
_make_stub(worker_cls, parse_argv, *args, **final_kwargs)
265+
_make_stub(worker_cls, parse_argv, pickle_func, *args, **final_kwargs)
265266
for _ in range(count)
266267
]
267268
return worker.FixedWorkerPool(workers=stubs, worker_concurrency=16)
@@ -283,13 +284,15 @@ class LocalWorkerPoolManager(AbstractContextManager):
283284

284285
def __init__(self,
285286
worker_class: 'type[worker.Worker]',
287+
pickle_func=pickle.dumps,
286288
*,
287289
count: int | None,
288290
worker_args: tuple = (),
289291
worker_kwargs: dict | None = None):
290292
worker_kwargs = {} if worker_kwargs is None else worker_kwargs
291293
self._pool = _create_local_worker_pool(worker_class, count, True,
292-
*worker_args, **worker_kwargs)
294+
pickle_func, *worker_args,
295+
**worker_kwargs)
293296

294297
def __enter__(self) -> worker.FixedWorkerPool:
295298
return self._pool

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from compiler_opt.distributed.local import local_worker_manager
2525
from tf_agents.system import system_multiprocessing as multiprocessing
2626

27-
flags.DEFINE_integer(
27+
_TEST_FLAG = flags.DEFINE_integer(
2828
'test_only_flag', default=1, help='A flag used by some tests.')
2929

3030

@@ -74,7 +74,7 @@ def method(self):
7474
class JobGetFlags(Worker):
7575

7676
def method(self):
77-
return {'argv': sys.argv, 'the_flag': flags.FLAGS.test_only_flag}
77+
return {'argv': sys.argv, 'the_flag': _TEST_FLAG.value}
7878

7979

8080
class LocalWorkerManagerTest(absltest.TestCase):

compiler_opt/es/blackbox_learner_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import os
1717
from absl.testing import absltest
18+
import cloudpickle
1819
import gin
1920
import tempfile
2021
import numpy as np
@@ -148,6 +149,7 @@ def test_run_step(self):
148149
with local_worker_manager.LocalWorkerPoolManager(
149150
blackbox_test_utils.ESWorker,
150151
count=3,
152+
pickle_func=cloudpickle.dumps,
151153
worker_args=('',),
152154
worker_kwargs=dict(kwarg='')) as pool:
153155
self._learner.run_step(pool) # pylint: disable=protected-access

0 commit comments

Comments
 (0)