18 changes: 12 additions & 6 deletions dygraph/mnist/train.py
@@ -24,6 +24,7 @@
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from paddle.fluid.dygraph.base import to_variable
 
+from paddle.distributed import fleet
 
 def parse_args():
     parser = argparse.ArgumentParser("Training for Mnist.")
@@ -174,8 +175,11 @@ def train_mnist(args):
     epoch_num = args.epoch
     BATCH_SIZE = 64
 
-    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if args.use_data_parallel else fluid.CUDAPlace(0)
+    if args.use_data_parallel:
+        place_idx = int(os.environ['FLAGS_selected_gpus'])
+        place = fluid.CUDAPlace(place_idx)
+    else:
+        place = fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         if args.ce:
             print("ce mode")
@@ -184,12 +188,14 @@ def train_mnist(args):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
         mnist = MNIST()
         adam = AdamOptimizer(learning_rate=0.001, parameter_list=mnist.parameters())
         if args.use_data_parallel:
-            mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+            fleet.init(is_collective=True)
+            dist_strategy = fleet.DistributedStrategy()
+            adam = fleet.distributed_optimizer(adam, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            mnist = fleet.build_distributed_model(mnist)
 
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
@@ -241,7 +247,7 @@ def train_mnist(args):
 
         save_parameters = (not args.use_data_parallel) or (
             args.use_data_parallel and
-            fluid.dygraph.parallel.Env().local_rank == 0)
+            fleet.worker_index() == 0)
         if save_parameters:
             fluid.save_dygraph(mnist.state_dict(), "save_temp")
 
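All three scripts in this change follow the same migration: the per-process GPU id comes from FLAGS_selected_gpus, fleet.init(is_collective=True) replaces prepare_context(), the optimizer is wrapped by fleet.distributed_optimizer, and the model by fleet.build_distributed_model. Below is a minimal sketch of that pattern, condensed from the mnist diff above; the model and optimizer names are placeholders and the training loop is elided, so it illustrates the call order rather than a drop-in train.py.

# Sketch only: condenses the pattern from the diff above; model/optimizer names
# are placeholders and the training loop is omitted.
import os

import paddle.fluid as fluid
from paddle.distributed import fleet
from paddle.fluid.optimizer import AdamOptimizer


def train(model_cls, use_data_parallel=True):
    if use_data_parallel:
        # the distributed launcher exports FLAGS_selected_gpus per worker process
        place = fluid.CUDAPlace(int(os.environ['FLAGS_selected_gpus']))
    else:
        place = fluid.CUDAPlace(0)

    with fluid.dygraph.guard(place):
        model = model_cls()
        opt = AdamOptimizer(
            learning_rate=0.001, parameter_list=model.parameters())

        if use_data_parallel:
            fleet.init(is_collective=True)
            dist_strategy = fleet.DistributedStrategy()
            # wrap the optimizer first so dist_strategy is applied,
            # then wrap the model (the order used in this diff)
            opt = fleet.distributed_optimizer(opt, dist_strategy)
            model = fleet.build_distributed_model(model)

        # ... training loop: forward, avg_loss.backward(), opt.minimize(avg_loss),
        #     model.clear_gradients() ...

        # only worker 0 writes checkpoints
        if (not use_data_parallel) or fleet.worker_index() == 0:
            fluid.save_dygraph(model.state_dict(), "save_temp")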
18 changes: 12 additions & 6 deletions dygraph/resnet/train.py
@@ -15,6 +15,7 @@
 import numpy as np
 import argparse
 import ast
+import os
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.layer_helper import LayerHelper
@@ -23,6 +24,7 @@
 
 from paddle.fluid import framework
 
+from paddle.distributed import fleet
 import math
 import sys
 import time
@@ -283,8 +285,11 @@ def eval(model, data):
 
 def train_resnet():
     epoch = args.epoch
-    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if args.use_data_parallel else fluid.CUDAPlace(0)
+    if args.use_data_parallel:
+        place_idx = int(os.environ['FLAGS_selected_gpus'])
+        place = fluid.CUDAPlace(place_idx)
+    else:
+        place = fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         if args.ce:
             print("ce mode")
@@ -293,14 +298,15 @@ def train_resnet():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
-
         resnet = ResNet()
         optimizer = optimizer_setting(parameter_list=resnet.parameters())
 
         if args.use_data_parallel:
-            resnet = fluid.dygraph.parallel.DataParallel(resnet, strategy)
+            fleet.init(is_collective=True)
+            dist_strategy = fleet.DistributedStrategy()
+            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            resnet = fleet.build_distributed_model(resnet)
 
         train_reader = paddle.batch(
             paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
15 changes: 9 additions & 6 deletions dygraph/transformer/train.py
@@ -21,6 +21,7 @@
 import numpy as np
 import paddle
 import paddle.fluid as fluid
+from paddle.distributed import fleet
 
 from utils.configure import PDConfig
 from utils.check import check_gpu, check_version
@@ -32,9 +33,9 @@
 
 def do_train(args):
     if args.use_cuda:
-        trainer_count = fluid.dygraph.parallel.Env().nranks
-        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id
-                                ) if trainer_count > 1 else fluid.CUDAPlace(0)
+        trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
+        place_idx = int(os.getenv('FLAGS_selected_gpus', 0))
+        place = fluid.CUDAPlace(place_idx)
     else:
         trainer_count = 1
         place = fluid.CPUPlace()
@@ -130,9 +131,11 @@ def do_train(args):
             transformer.load_dict(model_dict)
 
         if trainer_count > 1:
-            strategy = fluid.dygraph.parallel.prepare_context()
-            transformer = fluid.dygraph.parallel.DataParallel(
-                transformer, strategy)
+            fleet.init(is_collective=True)
+            dist_strategy = fleet.DistributedStrategy()
+            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            transformer = fleet.build_distributed_model(transformer)
 
         # the best cross-entropy value with label smoothing
         loss_normalizer = -(
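The environment variables read in these diffs (FLAGS_selected_gpus, PADDLE_TRAINERS_NUM) are populated by the distributed launcher; a plain `python train.py` run does not set them, which is why the transformer version falls back to defaults. A small sketch of that handling follows; the launch command in the comment assumes the 1.8-era paddle.distributed.launch flag names (newer releases rename --selected_gpus to --gpus), and "[script args]" stands for whatever arguments the individual script takes.

# Sketch: how the launcher-provided environment is consumed.
# Typical multi-GPU launch (flag names assume the 1.8-era launcher):
#   python -m paddle.distributed.launch --selected_gpus=0,1,2,3 train.py [script args]
# Each worker process then sees its own FLAGS_selected_gpus and the shared
# PADDLE_TRAINERS_NUM; without the launcher, the defaults below keep
# single-process runs working.
import os

import paddle.fluid as fluid

trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", 1))  # world size, 1 if unset
place_idx = int(os.getenv("FLAGS_selected_gpus", 0))      # this worker's GPU id
place = fluid.CUDAPlace(place_idx)
print("trainer_count=%d, gpu=%d" % (trainer_count, place_idx))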