
Commit 4359f6b

Fix multiprocess dataset map (#1511)
* fix multiprocess in dataset.map
* minor fix
1 parent 7fb53ae commit 4359f6b

File tree

1 file changed (+45, −44 lines)

paddlenlp/datasets/dataset.py

Lines changed: 45 additions & 44 deletions
@@ -199,28 +199,30 @@ def filter(self, fn, num_workers=0):
         """
         assert num_workers >= 0, "num_workers should be a non-negative value"
         if num_workers > 0:
-            with Pool(num_workers, initargs=(RLock(), )) as pool:
-
-                def filter_shard(num_workers, index, fn):
-                    self.shard(
-                        num_shards=num_workers, index=index, contiguous=True)
-                    self._filter(fn=fn)
-                    return self
-
-                kwds_per_shard = [
-                    dict(
-                        num_workers=num_workers, index=rank, fn=fn)
-                    for rank in range(num_workers)
-                ]
-                results = [
-                    pool.apply_async(
-                        filter_shard, kwds=kwds) for kwds in kwds_per_shard
-                ]
-                transformed_shards = [r.get() for r in results]
+            pool = Pool(
+                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
+
+            def filter_shard(num_workers, index, fn):
+                self.shard(num_shards=num_workers, index=index, contiguous=True)
+                self._filter(fn=fn)
+                return self
+
+            kwds_per_shard = [
+                dict(
+                    num_workers=num_workers, index=rank, fn=fn)
+                for rank in range(num_workers)
+            ]
+            results = [
+                pool.apply_async(
+                    filter_shard, kwds=kwds) for kwds in kwds_per_shard
+            ]
+            transformed_shards = [r.get() for r in results]
 
-                self.new_data = []
-                for i in range(num_workers):
-                    self.new_data += transformed_shards[i].new_data
+            pool.close()
+            pool.join()
+            self.new_data = []
+            for i in range(num_workers):
+                self.new_data += transformed_shards[i].new_data
             return self
         else:
             return self._filter(fn)
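
In both hunks the change swaps the `with Pool(...)` context manager, whose exit handler terminates the workers, for an explicitly managed pool that is closed and joined only after every shard result has been fetched, and caps each worker at 1000 tasks via `maxtasksperchild`. Below is a minimal standalone sketch of that shard / apply_async / get / close / join pattern; it uses the standard-library `multiprocessing` module and a toy `square` worker, neither of which is part of the PaddleNLP code above.

# Minimal sketch of the shard -> apply_async -> get -> close/join pattern.
# `square` and the toy data are illustrative; they are not PaddleNLP code.
from multiprocessing import Pool, RLock


def square(chunk):
    # Runs in a worker process on one contiguous shard of the data.
    return [x * x for x in chunk]


if __name__ == "__main__":
    data = list(range(100))
    num_workers = 4
    shard_size = (len(data) + num_workers - 1) // num_workers
    shards = [
        data[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)
    ]

    # Explicit pool instead of a `with` block; maxtasksperchild recycles each
    # worker process after 1000 tasks, as in the patch above.
    pool = Pool(num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
    results = [pool.apply_async(square, args=(shard, )) for shard in shards]
    transformed_shards = [r.get() for r in results]
    pool.close()   # no more tasks will be submitted
    pool.join()    # wait for the workers to exit

    merged = []
    for shard in transformed_shards:
        merged += shard
    print(len(merged))  # 100

Fetching the results with `get()` before `close()`/`join()` mirrors the ordering used in the patched methods.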
@@ -291,31 +293,30 @@ def map(self, fn, lazy=True, batched=False, num_workers=0):
 
         assert num_workers >= 0, "num_workers should be a non-negative value"
         if num_workers > 0:
-            with Pool(num_workers, initargs=(RLock(), )) as pool:
-
-                def map_shard(num_workers, index, fn, batched):
-                    self.shard(
-                        num_shards=num_workers, index=index, contiguous=True)
-                    self._map(fn=fn, lazy=False, batched=batched)
-                    return self
-
-                kwds_per_shard = [
-                    dict(
-                        num_workers=num_workers,
-                        index=rank,
-                        fn=fn,
-                        batched=batched) for rank in range(num_workers)
-                ]
-                results = [
-                    pool.apply_async(
-                        map_shard, kwds=kwds) for kwds in kwds_per_shard
-                ]
-                transformed_shards = [r.get() for r in results]
 
-                self.new_data = []
-                for i in range(num_workers):
-                    self.new_data += transformed_shards[i].new_data
+            def map_shard(num_workers, index, fn, batched):
+                self.shard(num_shards=num_workers, index=index, contiguous=True)
+                self._map(fn=fn, lazy=False, batched=batched)
+                return self
+
+            kwds_per_shard = [
+                dict(
+                    num_workers=num_workers, index=rank, fn=fn, batched=batched)
+                for rank in range(num_workers)
+            ]
+            pool = Pool(
+                num_workers, initargs=(RLock(), ), maxtasksperchild=1000)
+            results = [
+                pool.apply_async(
+                    map_shard, kwds=kwds) for kwds in kwds_per_shard
+            ]
 
+            transformed_shards = [r.get() for r in results]
+            pool.close()
+            pool.join()
+            self.new_data = []
+            for i in range(num_workers):
+                self.new_data += transformed_shards[i].new_data
             return self
         else:
             return self._map(fn, lazy=lazy, batched=batched)
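
For completeness, a hedged usage sketch of the patched entry points: it assumes the usual `paddlenlp.datasets.load_dataset` loader, and the dataset name, field names, and helper functions (`add_length`, `keep_long`) are illustrative only, not taken from this commit.

# Illustrative only: assumes load_dataset returns the MapDataset patched above.
from paddlenlp.datasets import load_dataset


def add_length(example):
    # 'text' is the field name assumed for this illustrative dataset.
    example["text_len"] = len(example["text"])
    return example


def keep_long(example):
    return example["text_len"] > 5


train_ds = load_dataset("chnsenticorp", splits="train")
train_ds.map(add_length, num_workers=4)      # sharded across 4 processes
train_ds.filter(keep_long, num_workers=4)    # same pattern for filter

Passing `num_workers=0` (the default) keeps the original single-process path taken by the `else` branches above.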
