@@ -28,10 +28,10 @@
 """
 import concurrent.futures
 import sys
-import cloudpickle
 import dataclasses
 import functools
 import multiprocessing
+import pickle
 import psutil
 import threading

@@ -80,7 +80,7 @@ def _run_impl(pipe: connection.Connection, worker_class: SerializedClass,
   pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
   # unpickle the type, which triggers all the necessary imports. This registers
   # abseil flags, which can then be parsed.
-  cls = cloudpickle.loads(worker_class)
+  cls = pickle.loads(worker_class)
   if parse_argv:
     flags.FLAGS(sys.argv, known_only=True)
   obj = cls(*args, **kwargs)
@@ -126,7 +126,8 @@ def _run(pipe: connection.Connection, worker_class: SerializedClass,
     raise e


-def _make_stub(cls: 'type[worker.Worker]', parse_argv: bool, *args, **kwargs):
+def _make_stub(cls: 'type[worker.Worker]', parse_argv: bool, pickle_func, *args,
+               **kwargs):

   class _Stub:
     """Client stub to a worker hosted by a process."""
@@ -141,7 +142,7 @@ def __init__(self):
       # to handle high priority requests. The expectation is that the user
       # achieves concurrency through multiprocessing, not multithreading.
       self._process = _get_context().Process(
-          target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls),
+          target=functools.partial(_run, child_pipe, pickle_func(cls),
                                    parse_argv, *args, **kwargs))
       # lock for the msgid -> reply future map. The map will be set to None
       # when we stop.
@@ -254,14 +255,14 @@ def __dir__(self):


 def _create_local_worker_pool(worker_cls: 'type[worker.Worker]',
-                              count: int | None, parse_argv: bool, *args,
-                              **kwargs) -> worker.FixedWorkerPool:
+                              count: int | None, parse_argv: bool, pickle_func,
+                              *args, **kwargs) -> worker.FixedWorkerPool:
   """Create a local worker pool for worker_cls."""
   if not count:
     count = _get_context().cpu_count()
   final_kwargs = worker.get_full_worker_args(worker_cls, **kwargs)
   stubs = [
-      _make_stub(worker_cls, parse_argv, *args, **final_kwargs)
+      _make_stub(worker_cls, parse_argv, pickle_func, *args, **final_kwargs)
       for _ in range(count)
   ]
   return worker.FixedWorkerPool(workers=stubs, worker_concurrency=16)
@@ -283,13 +284,15 @@ class LocalWorkerPoolManager(AbstractContextManager):

   def __init__(self,
                worker_class: 'type[worker.Worker]',
+               pickle_func=pickle.dumps,
                *,
                count: int | None,
                worker_args: tuple = (),
                worker_kwargs: dict | None = None):
     worker_kwargs = {} if worker_kwargs is None else worker_kwargs
     self._pool = _create_local_worker_pool(worker_class, count, True,
-                                           *worker_args, **worker_kwargs)
+                                           pickle_func, *worker_args,
+                                           **worker_kwargs)

   def __enter__(self) -> worker.FixedWorkerPool:
     return self._pool
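The change above swaps the worker-class serializer from cloudpickle to the standard-library pickle by default, and threads a new pickle_func parameter through LocalWorkerPoolManager, _create_local_worker_pool, and _make_stub so callers can still opt into cloudpickle. A minimal usage sketch follows; the import paths and the MyWorker class are hypothetical, and only LocalWorkerPoolManager, pickle_func, and count come from this diff. pickle.dumps serializes a class by reference, so the child process must be able to import it by name, while cloudpickle.dumps serializes it by value, which also covers dynamically defined classes.

import cloudpickle  # optional; only needed to opt back into by-value pickling

from local_worker_manager import LocalWorkerPoolManager  # assumed import path
from my_workers import MyWorker  # hypothetical worker.Worker subclass

# Default pickle_func=pickle.dumps: MyWorker must be importable by name in the
# spawned child process (any ordinary module-level class qualifies).
with LocalWorkerPoolManager(MyWorker, count=4) as pool:
  pass  # dispatch work to pool.workers here

# Explicit opt-in to cloudpickle for classes pickle cannot reference by name,
# e.g. classes created in a notebook cell or inside a function body.
with LocalWorkerPoolManager(
    MyWorker, pickle_func=cloudpickle.dumps, count=4) as pool:
  pass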