Skip to content

Commit 5ce1a96

Browse files
committed
move bcast op into pass
1 parent b681537 commit 5ce1a96

File tree

12 files changed

+82
-30
lines changed

12 files changed

+82
-30
lines changed

benchmark/fluid/args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,5 +140,11 @@ def parse_args():
140140
'--use_lars',
141141
action='store_true',
142142
help='If set, use lars for optimizers, ONLY support resnet module.')
143+
parser.add_argument(
144+
'--reduce_strategy',
145+
type=str,
146+
choices=['reduce', 'all_reduce'],
147+
default='all_reduce',
148+
help='Specify the reduce strategy, can be reduce, all_reduce')
143149
args = parser.parse_args()
144150
return args

benchmark/fluid/fluid_benchmark.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
170170
strategy = fluid.ExecutionStrategy()
171171
strategy.num_threads = args.cpus
172172
strategy.allow_op_delay = False
173+
build_strategy = fluid.BuildStrategy()
174+
if args.reduce_strategy == "reduce":
175+
build_strategy.reduce_strategy = fluid.BuildStrategy(
176+
).ReduceStrategy.Reduce
177+
else:
178+
build_strategy.reduce_strategy = fluid.BuildStrategy(
179+
).ReduceStrategy.AllReduce
180+
173181
avg_loss = train_args[0]
174182

175183
if args.update_method == "pserver":
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
184192
avg_loss.name,
185193
main_program=train_prog,
186194
exec_strategy=strategy,
195+
build_strategy=build_strategy,
187196
num_trainers=num_trainers,
188197
trainer_id=trainer_id)
189198

benchmark/fluid/models/mnist.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,14 @@ def cnn_model(data):
6767

6868
def get_model(args, is_train, main_prog, startup_prog):
6969
# NOTE: mnist is small, we don't implement data sharding yet.
70-
filelist = [
71-
os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
72-
]
70+
opt = None
71+
data_file_handle = None
7372
with fluid.program_guard(main_prog, startup_prog):
7473
if args.use_reader_op:
74+
filelist = [
75+
os.path.join(args.data_path, f)
76+
for f in os.listdir(args.data_path)
77+
]
7578
data_file_handle = fluid.layers.open_files(
7679
filenames=filelist,
7780
shapes=[[-1, 1, 28, 28], (-1, 1)],
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
100103
if is_train:
101104
opt = fluid.optimizer.AdamOptimizer(
102105
learning_rate=0.001, beta1=0.9, beta2=0.999)
103-
opt.minimize()
106+
opt.minimize(avg_cost)
104107
if args.memory_optimize:
105108
fluid.memory_optimize(main_prog)
106109

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
4646
#endif
4747

4848
void AllReduceOpHandle::RunImpl() {
49-
platform::RecordEvent r("all_reduce", nullptr);
49+
if (dev_ctxes_.size() > 0UL) {
50+
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
51+
} else {
52+
platform::RecordEvent record_event(Name(), nullptr);
53+
}
54+
5055
if (NoDummyInputSize() == 1) {
5156
return; // No need to all reduce when GPU count = 1;
5257
} else {

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
1616
#include "paddle/fluid/framework/details/container_cast.h"
1717
#include "paddle/fluid/framework/details/variable_visitor.h"
18+
#include "paddle/fluid/platform/profiler.h"
1819

1920
namespace paddle {
2021
namespace framework {
2122
namespace details {
2223

2324
void BroadcastOpHandle::RunImpl() {
25+
if (dev_ctxes_.size() > 0UL) {
26+
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
27+
} else {
28+
platform::RecordEvent record_event(Name(), nullptr);
29+
}
30+
2431
if (places_.size() == 1) return;
2532

2633
// The input and output may have dummy vars.

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
1616
#include <algorithm>
1717
#include "paddle/fluid/framework/details/container_cast.h"
18+
#include "paddle/fluid/platform/profiler.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
8687
}
8788

8889
void DataBalanceOpHandle::RunImpl() {
90+
if (dev_ctxes_.size() > 0UL) {
91+
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
92+
} else {
93+
platform::RecordEvent record_event(Name(), nullptr);
94+
}
95+
8996
PADDLE_ENFORCE_GT(places_.size(), 1,
9097
"Data balance can only be enabled when the number of "
9198
"places to run larger than 1.");

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
348348

349349
size_t cur_device_id = 0;
350350
bool is_forwarding = true;
351+
bool is_dist_train = false;
351352

352353
for (ir::Node *node : sorted_ops) {
353354
if (boost::get<int>(
354355
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
355356
static_cast<int>(OpRole::kRPC)) {
356-
CreateRPCOp(&result, node);
357+
int op_dev_id = CreateRPCOp(&result, node);
358+
PADDLE_ENFORCE(op_dev_id != -1,
359+
"Can not schedule the RPC operator to the right place.");
360+
if (node->Op()->Type() == "recv") {
361+
auto recv_vars_attr =
362+
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
363+
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
364+
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
365+
if (recv_vars_attr[0].find(".block") == std::string::npos) {
366+
bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
367+
}
368+
}
369+
is_dist_train = true;
357370
} else if (IsDistTrainOp(node, send_vars, recv_vars)) {
358-
CreateDistTrainOp(&result, node);
371+
int op_dev_id = CreateDistTrainOp(&result, node);
372+
if (node->Op()->Type() == "concat") {
373+
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
374+
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
375+
}
359376
} else if (IsScaleLossOp(node)) {
360377
// user can customize loss@grad if not use_default_grad_scale_
361378
if (strategy_.gradient_scale_ !=
@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
414431
CreateReduceOp(&result, g_name, cur_device_id);
415432
graph->Get<ShardedVarDevice>(kShardedVarDevice)
416433
.emplace(g_name, cur_device_id);
417-
bcast_var_name_set[cur_device_id].emplace(p_name);
434+
if (!is_dist_train) {
435+
// will send gradients directly when distributed training
436+
bcast_var_name_set[cur_device_id].emplace(p_name);
437+
}
418438
break;
419439
case BuildStrategy::ReduceStrategy::kAllReduce:
420440
if (IsSparseGradient(g_name)) {
@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
436456
}
437457
}
438458
}
439-
440459
bool use_gpu = false;
441460
#ifdef PADDLE_WITH_CUDA
442461
use_gpu = nccl_ctxs_ != nullptr;
443462
#endif
444463

445-
if (use_gpu ||
446-
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
464+
if ((use_gpu &&
465+
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
466+
is_dist_train) {
447467
// Insert BCast Ops
448468
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
449469
auto &to_bcast_set = bcast_var_name_set[dev_id];
@@ -676,8 +696,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
676696
return var;
677697
}
678698

679-
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
680-
ir::Node *node) const {
699+
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
700+
ir::Node *node) const {
681701
int op_dev_id = -1;
682702
std::vector<std::string> input_var_names;
683703
std::vector<std::string> output_var_names;
@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
720740
node->Op()->Type());
721741

722742
CreateComputationalOp(result, node, op_dev_id);
743+
return op_dev_id;
723744
}
724745

725746
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
@@ -738,8 +759,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
738759
}
739760

740761
// Create RPC related op handles that connects its in ops and out ops.
741-
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
742-
ir::Node *node) const {
762+
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
763+
ir::Node *node) const {
743764
int op_dev_id = -1;
744765
if (node->Op()->Type() == "send") {
745766
// TODO(paddle-dev): getting the first var is not safe.
@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
825846
CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
826847
}
827848
}
849+
return op_dev_id;
828850
}
829851

830852
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
5454

5555
bool IsScaleLossOp(ir::Node *node) const;
5656

57-
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
58-
void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
57+
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
58+
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
5959

6060
/**
6161
* Is this operator as the end-point operator before/after send operator.

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ namespace framework {
2727
namespace details {
2828

2929
void ReduceOpHandle::RunImpl() {
30-
platform::RecordEvent r("reduce", nullptr);
30+
if (dev_ctxes_.size() > 0UL) {
31+
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
32+
} else {
33+
platform::RecordEvent record_event(Name(), nullptr);
34+
}
3135
if (places_.size() == 1) return;
3236
// the input and output may have dummy var.
3337
auto in_var_handles = DynamicCast<VarHandle>(inputs_);

paddle/fluid/framework/details/scale_loss_grad_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
5151
->stream();
5252
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
5353
platform::CPUPlace(), &coeff_, sizeof(float), stream);
54-
VLOG(1) << place_ << "RUN Scale loss grad op";
54+
VLOG(10) << place_ << "RUN Scale loss grad op";
5555
});
5656
#endif
5757
}

0 commit comments

Comments
 (0)