@@ -199,28 +199,30 @@ def filter(self, fn, num_workers=0):
199
199
"""
200
200
assert num_workers >= 0 , "num_workers should be a non-negative value"
201
201
if num_workers > 0 :
202
- with Pool ( num_workers , initargs = ( RLock (), )) as pool :
203
-
204
- def filter_shard ( num_workers , index , fn ):
205
- self . shard (
206
- num_shards = num_workers , index = index , contiguous = True )
207
- self ._filter (fn = fn )
208
- return self
209
-
210
- kwds_per_shard = [
211
- dict (
212
- num_workers = num_workers , index = rank , fn = fn )
213
- for rank in range (num_workers )
214
- ]
215
- results = [
216
- pool .apply_async (
217
- filter_shard , kwds = kwds ) for kwds in kwds_per_shard
218
- ]
219
- transformed_shards = [r .get () for r in results ]
202
+ pool = Pool (
203
+ num_workers , initargs = ( RLock (), ), maxtasksperchild = 1000 )
204
+
205
+ def filter_shard ( num_workers , index , fn ):
206
+ self . shard ( num_shards = num_workers , index = index , contiguous = True )
207
+ self ._filter (fn = fn )
208
+ return self
209
+
210
+ kwds_per_shard = [
211
+ dict (
212
+ num_workers = num_workers , index = rank , fn = fn )
213
+ for rank in range (num_workers )
214
+ ]
215
+ results = [
216
+ pool .apply_async (
217
+ filter_shard , kwds = kwds ) for kwds in kwds_per_shard
218
+ ]
219
+ transformed_shards = [r .get () for r in results ]
220
220
221
- self .new_data = []
222
- for i in range (num_workers ):
223
- self .new_data += transformed_shards [i ].new_data
221
+ pool .close ()
222
+ pool .join ()
223
+ self .new_data = []
224
+ for i in range (num_workers ):
225
+ self .new_data += transformed_shards [i ].new_data
224
226
return self
225
227
else :
226
228
return self ._filter (fn )
@@ -291,31 +293,30 @@ def map(self, fn, lazy=True, batched=False, num_workers=0):
291
293
292
294
assert num_workers >= 0 , "num_workers should be a non-negative value"
293
295
if num_workers > 0 :
294
- with Pool (num_workers , initargs = (RLock (), )) as pool :
295
-
296
- def map_shard (num_workers , index , fn , batched ):
297
- self .shard (
298
- num_shards = num_workers , index = index , contiguous = True )
299
- self ._map (fn = fn , lazy = False , batched = batched )
300
- return self
301
-
302
- kwds_per_shard = [
303
- dict (
304
- num_workers = num_workers ,
305
- index = rank ,
306
- fn = fn ,
307
- batched = batched ) for rank in range (num_workers )
308
- ]
309
- results = [
310
- pool .apply_async (
311
- map_shard , kwds = kwds ) for kwds in kwds_per_shard
312
- ]
313
- transformed_shards = [r .get () for r in results ]
314
296
315
- self .new_data = []
316
- for i in range (num_workers ):
317
- self .new_data += transformed_shards [i ].new_data
297
+ def map_shard (num_workers , index , fn , batched ):
298
+ self .shard (num_shards = num_workers , index = index , contiguous = True )
299
+ self ._map (fn = fn , lazy = False , batched = batched )
300
+ return self
301
+
302
+ kwds_per_shard = [
303
+ dict (
304
+ num_workers = num_workers , index = rank , fn = fn , batched = batched )
305
+ for rank in range (num_workers )
306
+ ]
307
+ pool = Pool (
308
+ num_workers , initargs = (RLock (), ), maxtasksperchild = 1000 )
309
+ results = [
310
+ pool .apply_async (
311
+ map_shard , kwds = kwds ) for kwds in kwds_per_shard
312
+ ]
318
313
314
+ transformed_shards = [r .get () for r in results ]
315
+ pool .close ()
316
+ pool .join ()
317
+ self .new_data = []
318
+ for i in range (num_workers ):
319
+ self .new_data += transformed_shards [i ].new_data
319
320
return self
320
321
else :
321
322
return self ._map (fn , lazy = lazy , batched = batched )
0 commit comments