Skip to content

Commit 7890f05

Browse files
committed
add split dataset
1 parent c0e2e55 commit 7890f05

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

doc/collective_mode.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ python -m paddle.distributed.launch --ips="xx.xx.xx.xx,yy.yy.yy.yy" --gpus 0,1,2
5656

5757
## 修改reader
5858
目前我们paddlerec模型默认使用的reader都是继承自paddle.io.IterableDataset,在reader的__iter__函数中拆分文件,按行处理数据。当 paddle.io.DataLoader 中 num_workers > 0 时,每个子进程都会遍历全量的数据集返回全量样本,所以数据集会重复 num_workers 次,也就是每张卡都会获得全部的数据。您在训练时可能需要调整学习率等参数以保证训练效果。
59-
如果需要数据集样本不会重复,可通过 [paddle.io.get_worker_info](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/dataloader/dataloader_iter/get_worker_info_cn.html#get-worker-info) 获取各子进程的信息。并在 __iter__ 函数中划分各子进程的数据[paddle.io.IterableDataset](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/dataloader/dataset/IterableDataset_cn.html#iterabledataset)的相关信息以及划分数据的示例可以点击这里获取。
59+
如果需要数据集样本不会重复,可通过paddle.distributed.get_rank()函数获取当前使用的第几张卡,paddle.distributed.get_world_size()函数获取使用的总卡数。并在reader文件中自行添加逻辑划分各子进程的数据[paddle.io.IterableDataset](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/dataloader/dataset/IterableDataset_cn.html#iterabledataset)的相关信息以及划分数据的示例可以点击这里获取。

models/rank/wide_deep/criteo_reader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,29 @@
1414

1515
from __future__ import print_function
1616
import numpy as np
17-
17+
import paddle
1818
from paddle.io import IterableDataset
1919

2020

2121
class RecDataset(IterableDataset):
2222
def __init__(self, file_list, config):
2323
super(RecDataset, self).__init__()
2424
self.file_list = file_list
25+
use_fleet = config.get("runner.use_fleet", False)
26+
if use_fleet:
27+
worker_id = paddle.distributed.get_rank()
28+
worker_num = paddle.distributed.get_world_size()
29+
file_num = len(file_list)
30+
if file_num < worker_num:
31+
raise ValueError(
32+
"The number of data files is less than the number of workers"
33+
)
34+
blocksize = int(file_num / worker_num)
35+
self.file_list = file_list[worker_id * blocksize:(worker_id + 1) *
36+
blocksize]
37+
remainder = file_num - (blocksize * worker_num)
38+
if worker_id < remainder:
39+
self.file_list.append(file_list[-(worker_id + 1)])
2540
self.init()
2641

2742
def init(self):

0 commit comments

Comments
 (0)