@@ -12,7 +12,15 @@ def main_process_print(self, *args, sep=' ', end='\n', file=None):
1212 print (self , * args , sep = sep , end = end , file = file )
1313
1414
15- def chunked_worker_run (map_func , args , results_queue = None ):
15+ def chunked_worker_run (map_func , args , results_queue = None , device_id = None ):
16+ if device_id is not None :
17+ try :
18+ import torch
19+ torch .cuda .set_device (device_id )
20+ if hasattr (map_func , '__self__' ) and map_func .__self__ is not None :
21+ map_func .__self__ .device = torch .device (f'cuda:{ device_id } ' )
22+ except Exception :
23+ traceback .print_exc ()
1624 for a in args :
1725 # noinspection PyBroadException
1826 try :
@@ -25,10 +33,15 @@ def chunked_worker_run(map_func, args, results_queue=None):
2533 results_queue .put (None )
2634
2735
28- def chunked_multiprocess_run (map_func , args , num_workers , q_max_size = 1000 ):
36+ def chunked_multiprocess_run (map_func , args , num_workers , q_max_size = 1000 , device_ids = None ):
2937 num_jobs = len (args )
3038 if num_jobs < num_workers :
3139 num_workers = num_jobs
40+ if device_ids is not None :
41+ device_ids = device_ids [:num_workers ]
42+
43+ if device_ids is not None :
44+ assert len (device_ids ) == num_workers
3245
3346 queues = [Manager ().Queue (maxsize = q_max_size // num_workers ) for _ in range (num_workers )]
3447 if platform .system ().lower () != 'windows' :
@@ -39,7 +52,9 @@ def chunked_multiprocess_run(map_func, args, num_workers, q_max_size=1000):
3952 workers = []
4053 for i in range (num_workers ):
4154 worker = process_creation_func (
42- target = chunked_worker_run , args = (map_func , args [i ::num_workers ], queues [i ]), daemon = True
55+ target = chunked_worker_run ,
56+ args = (map_func , args [i ::num_workers ], queues [i ], None if device_ids is None else device_ids [i ]),
57+ daemon = True
4358 )
4459 workers .append (worker )
4560 worker .start ()
0 commit comments