
Commit 31a316f

Merge pull request #425 from yinhaofeng/collective_train

collective_train

2 parents e56ccd4 + 845920b

File tree

8 files changed: +126 −48 lines changed

doc/collective_mode.md

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
# Running in Collective Mode

If you want to train your model faster by using several GPUs at once, try the `single-node multi-GPU` or `multi-node multi-GPU` mode.

## Version requirements

Make sure paddlepaddle-2.0.0-rc-gpu or a later GPU build of the PaddlePaddle open-source framework is installed.

## Setting up config.yaml

First add the use_fleet option to the model's yaml configuration and set it to True. Also set use_gpu to True:

```yaml
runner:
  # common options not repeated here
  ...
  # use fleet
  use_fleet: True
```

## Single-node multi-GPU training

### Choosing which GPUs to use

If nothing is set, all GPUs on the machine are used. To run on specific GPUs only, set the CUDA_VISIBLE_DEVICES environment variable. For example, on a machine with 8 GPUs, to train on the first 4 only, run `export CUDA_VISIBLE_DEVICES=0,1,2,3` and then start the training script.
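A complete single-node session might then look like this (a minimal sketch; the relative script path follows the launch commands shown in the next section):

```bash
# use only the first four of eight GPUs, then launch the dynamic-graph trainer
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch ../../../tools/trainer.py -m config.yaml
```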
### Launching training

```bash
# dynamic-graph training
python -m paddle.distributed.launch ../../../tools/trainer.py -m config.yaml
# static-graph training
python -m paddle.distributed.launch ../../../tools/static_trainer.py -m config.yaml
```

Note: when training with the static graph, make sure the create_optimizer function in the model's static_model.py sets up a distributed optimizer:

```python
def create_optimizer(self, strategy=None):
    optimizer = paddle.optimizer.Adam(
        learning_rate=self.learning_rate, lazy_mode=True)
    # fetch a distributed optimizer from the Fleet API, wrapping Paddle's base optimizer
    if strategy is not None:
        import paddle.distributed.fleet as fleet
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
    optimizer.minimize(self._cost)
```
## Multi-node multi-GPU training

Multi-node training requires one or more additional machines that can ping each other. Every machine must have paddlepaddle-2.0.0-rc-gpu or a later GPU build of the framework installed, and the PaddleRec model and its dataset must be copied to every machine.
- Make sure the nodes are connected and reachable from one another by IP.
- Every node must hold the code and the data.
- Run the launch command on every node.

Going from single-node to multi-node training requires no code changes; simply pass the extra ips argument, a comma-separated list of the machines' IPs:

```bash
# dynamic-graph training
python -m paddle.distributed.launch --ips="xx.xx.xx.xx,yy.yy.yy.yy" --gpus 0,1,2,3,4,5,6,7 ../../../tools/trainer.py -m config.yaml
# static-graph training
python -m paddle.distributed.launch --ips="xx.xx.xx.xx,yy.yy.yy.yy" --gpus 0,1,2,3,4,5,6,7 ../../../tools/static_trainer.py -m config.yaml
```
## Adapting the reader

The readers PaddleRec models use by default all inherit from paddle.io.IterableDataset, splitting input files and processing the data line by line inside the reader's __iter__ function. When num_workers > 0 in paddle.io.DataLoader, every worker subprocess traverses the full dataset and returns every sample, so the dataset is repeated num_workers times and each GPU receives all of the data. You may need to adjust the learning rate and other hyperparameters to preserve training quality.

To keep samples from being duplicated, query each worker subprocess with [paddle.io.get_worker_info](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/dataloader/dataloader_iter/get_worker_info_cn.html#get-worker-info) and partition the data inside __iter__. See [paddle.io.IterableDataset](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/dataloader/dataset/IterableDataset_cn.html#iterabledataset) for details and a data-splitting example.
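As a minimal sketch of that partitioning (a hypothetical line-oriented reader, not PaddleRec's shipped one), each worker keeps only its own slice of the file list:

```python
from paddle.io import IterableDataset, get_worker_info

class ShardedLineDataset(IterableDataset):
    """Yields each line of a file list exactly once across DataLoader workers."""

    def __init__(self, file_list):
        super().__init__()
        self.file_list = file_list  # hypothetical list of text-file paths

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            files = self.file_list  # single-process loading: read everything
        else:
            # round-robin split of the files by worker id
            files = self.file_list[info.id::info.num_workers]
        for path in files:
            with open(path) as f:
                for line in f:
                    yield line.rstrip("\n")
```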

doc/yaml.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -24,6 +24,7 @@
 | use_inference | bool | True/False || whether to save the model through the save_inference_model interface |
 | save_inference_feed_varnames | list[string] | name of a designated Variable in the network || input variable names of the inference model |
 | save_inference_fetch_varnames | list[string] | name of a designated Variable in the network || output variable names of the inference model |
+| use_fleet | bool | True/False || whether to use distributed execution for single-node or multi-node multi-GPU training |


 ## hyper_parameters variables
```

models/rank/wide_deep/config.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@ runner:
   train_reader_path: "criteo_reader" # importlib format
   use_gpu: False
   use_auc: True
-  train_batch_size: 2
+  train_batch_size: 50
   epochs: 3
   print_interval: 2
   # model_init_path: "output_model_wide_deep/2" # init model
@@ -34,6 +34,8 @@ runner:
   use_inference: False
   save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"]
   save_inference_fetch_varnames: ["sigmoid_0.tmp_0"]
+  # use fleet
+  use_fleet: False

 # hyper parameters of user-defined network
 hyper_parameters:
```

tools/infer.py

Lines changed: 4 additions & 13 deletions

```diff
@@ -12,19 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 import paddle
 import os
 import paddle.nn as nn
@@ -68,6 +55,10 @@ def main(args):
     for parameter in args.opt:
         parameter = parameter.strip()
         key, value = parameter.split("=")
+        if type(config.get(key)) is int:
+            value = int(value)
+        if type(config.get(key)) is bool:
+            value = (True if value.lower() == "true" else False)
         config[key] = value

     # tools.vars
```
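With this change, typed overrides survive the round trip from the command line: a value such as "true" is cast back to a real bool (and numeric strings to int) when the config already stores that type. A hypothetical invocation, assuming the tools expose these key=value pairs through an -o/--opt flag as the loop over args.opt suggests (the option names below are illustrative):

```bash
# override typed runner options without editing config.yaml
python -u ../../../tools/infer.py -m config.yaml \
    -o runner.use_gpu=True runner.infer_batch_size=50
```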

tools/static_infer.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -53,6 +53,10 @@ def main(args):
     for parameter in args.opt:
         parameter = parameter.strip()
         key, value = parameter.split("=")
+        if type(config.get(key)) is int:
+            value = int(value)
+        if type(config.get(key)) is bool:
+            value = (True if value.lower() == "true" else False)
         config[key] = value
     # load static model class
     static_model_class = load_static_model_class(config)
```

tools/static_trainer.py

Lines changed: 30 additions & 6 deletions

```diff
@@ -55,6 +55,10 @@ def main(args):
     for parameter in args.opt:
         parameter = parameter.strip()
         key, value = parameter.split("=")
+        if type(config.get(key)) is int:
+            value = int(value)
+        if type(config.get(key)) is bool:
+            value = (True if value.lower() == "true" else False)
         config[key] = value
     # load static model class
     static_model_class = load_static_model_class(config)
@@ -63,9 +67,9 @@ def main(args):
     input_data_names = [data.name for data in input_data]

     fetch_vars = static_model_class.net(input_data)
+
     #infer_target_var = model.infer_target_var
     logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
-    static_model_class.create_optimizer()

     use_gpu = config.get("runner.use_gpu", True)
     use_auc = config.get("runner.use_auc", False)
@@ -79,6 +83,7 @@ def main(args):
     model_init_path = config.get("runner.model_init_path", None)
     batch_size = config.get("runner.train_batch_size", None)
     reader_type = config.get("runner.reader_type", "DataLoader")
+    use_fleet = config.get("runner.use_fleet", False)
     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
     logger.info("**************common.configs**********")
     logger.info(
@@ -88,6 +93,16 @@ def main(args):
     logger.info("**************common.configs**********")

     place = paddle.set_device('gpu' if use_gpu else 'cpu')
+
+    if use_fleet:
+        from paddle.distributed import fleet
+        strategy = fleet.DistributedStrategy()
+        fleet.init(is_collective=True, strategy=strategy)
+    if use_fleet:
+        static_model_class.create_optimizer(strategy)
+    else:
+        static_model_class.create_optimizer()
+
     exe = paddle.static.Executor(place)
     # initialize
     exe.run(paddle.static.default_startup_program())
@@ -132,11 +147,20 @@ def main(args):
         else:
             logger.info("reader type wrong")

-        save_static_model(
-            paddle.static.default_main_program(),
-            model_save_path,
-            epoch_id,
-            prefix='rec_static')
+        if use_fleet:
+            trainer_id = paddle.distributed.get_rank()
+            if trainer_id == 0:
+                save_static_model(
+                    paddle.static.default_main_program(),
+                    model_save_path,
+                    epoch_id,
+                    prefix='rec_static')
+        else:
+            save_static_model(
+                paddle.static.default_main_program(),
+                model_save_path,
+                epoch_id,
+                prefix='rec_static')

     if use_inference:
         feed_var_names = config.get("runner.save_inference_feed_varnames",
```
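Distilled from the diff above, the static-graph collective pattern is: initialize fleet with a DistributedStrategy before building the optimizer, pass the strategy into create_optimizer, and save only on rank 0. A minimal self-contained sketch of that pattern (toy network, not PaddleRec's static_model), launched via python -m paddle.distributed.launch:

```python
import paddle
from paddle.distributed import fleet

paddle.enable_static()

# initialize collective mode before creating the optimizer
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)

# toy network standing in for static_model_class.net()
x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
cost = paddle.mean(paddle.static.nn.fc(x, size=1))

# wrap the base optimizer with the fleet distributed optimizer
optimizer = fleet.distributed_optimizer(paddle.optimizer.Adam(), strategy)
optimizer.minimize(cost)

exe = paddle.static.Executor(paddle.set_device("gpu"))
exe.run(paddle.static.default_startup_program())
```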

tools/to_static.py

Lines changed: 0 additions & 13 deletions

```diff
@@ -12,19 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 import paddle
 import os
 import paddle.nn as nn
```

tools/trainer.py

Lines changed: 25 additions & 15 deletions

```diff
@@ -12,19 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 import paddle
 import os
 import paddle.nn as nn
@@ -68,6 +55,10 @@ def main(args):
     for parameter in args.opt:
         parameter = parameter.strip()
         key, value = parameter.split("=")
+        if type(config.get(key)) is int:
+            value = int(value)
+        if type(config.get(key)) is bool:
+            value = (True if value.lower() == "true" else False)
         config[key] = value

     # tools.vars
@@ -79,6 +70,7 @@ def main(args):
     train_batch_size = config.get("runner.train_batch_size", None)
     model_save_path = config.get("runner.model_save_path", "model_output")
     model_init_path = config.get("runner.model_init_path", None)
+    use_fleet = config.get("runner.use_fleet", False)

     logger.info("**************common.configs**********")
     logger.info(
@@ -102,6 +94,14 @@ def main(args):
     # to do : add optimizer function
     optimizer = dy_model_class.create_optimizer(dy_model, config)

+    # use fleet to run collective training
+    if use_fleet:
+        from paddle.distributed import fleet
+        strategy = fleet.DistributedStrategy()
+        fleet.init(is_collective=True, strategy=strategy)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        dy_model = fleet.distributed_model(dy_model)
+
     logger.info("read data")
     train_dataloader = create_data_loader(config=config, place=place)

@@ -186,8 +186,18 @@ def main(args):
             tensor_print_str + " epoch time: {:.2f} s".format(
                 time.time() - epoch_begin))

-        save_model(
-            dy_model, optimizer, model_save_path, epoch_id, prefix='rec')
+        if use_fleet:
+            trainer_id = paddle.distributed.get_rank()
+            if trainer_id == 0:
+                save_model(
+                    dy_model,
+                    optimizer,
+                    model_save_path,
+                    epoch_id,
+                    prefix='rec')
+        else:
+            save_model(
+                dy_model, optimizer, model_save_path, epoch_id, prefix='rec')


 if __name__ == '__main__':
```
