6868import sys
6969import time
7070from collections import defaultdict , namedtuple
71+ from contextvars import ContextVar
72+ from multiprocessing .dummy import Pool
7173
7274import networkx as nx
7375from boltons .setutils import IndexedSet as iset
7880
7981log = logging .getLogger (__name__ )
8082
83+ thread_pool : ContextVar [Pool ] = ContextVar ("thread_pool" , default = Pool (7 ))
8184
8285if sys .version_info < (3 , 6 ):
8386 """
@@ -213,9 +216,7 @@ def _call_operation(self, op, solution):
213216 except Exception as ex :
214217 jetsam (ex , locals (), plan = "self" )
215218
216- def _execute_thread_pool_barrier_method (
217- self , solution , overwrites , executed , thread_pool_size = 10
218- ):
219+ def _execute_thread_pool_barrier_method (self , solution , overwrites , executed ):
219220 """
220221 This method runs the graph using a parallel pool of thread executors.
221222 You may achieve lower total latency if your graph is sufficiently
@@ -224,17 +225,12 @@ def _execute_thread_pool_barrier_method(
224225 :param solution:
225226 must contain the input values only, gets modified
226227 """
227- from multiprocessing .dummy import Pool
228-
229228 # Keep original inputs for pinning.
230229 pinned_values = {
231230 n : solution [n ] for n in self .steps if isinstance (n , _PinInstruction )
232231 }
233232
234- # if we have not already created a thread_pool, create one
235- if not hasattr (self .net , "_thread_pool" ):
236- self .net ._thread_pool = Pool (thread_pool_size )
237- pool = self .net ._thread_pool
233+ pool = thread_pool .get ()
238234
239235 # with each loop iteration, we determine a set of operations that can be
240236 # scheduled, then schedule them onto a thread pool, then collect their
@@ -288,7 +284,6 @@ def _execute_thread_pool_barrier_method(
288284 if len (upnext ) == 0 :
289285 break
290286
291- ## TODO: accept pool from caller
292287 done_iterator = pool .imap_unordered (
293288 (lambda op : (op , self ._call_operation (op , solution ))), upnext
294289 )
0 commit comments