
Commit 7719349

Merge pull request #14122 from typhoonzero/cherrypick_14072_13766
cherry pick 13766
2 parents: 863f80e + cb27415

34 files changed: +1300 -106 lines

benchmark/fluid/args.py (5 additions, 0 deletions)

@@ -142,5 +142,10 @@ def parse_args():
         choices=['reduce', 'all_reduce'],
         default='all_reduce',
         help='Specify the reduce strategy, can be reduce, all_reduce')
+    parser.add_argument(
+        '--fuse_broadcast_op',
+        action='store_true',
+        help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.'
+    )
     args = parser.parse_args()
     return args

benchmark/fluid/fluid_benchmark.py (1 addition, 1 deletion)

@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
     else:
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.AllReduce
+    build_strategy.fuse_broadcast_op = args.fuse_broadcast_op

     avg_loss = train_args[0]

@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,

         if args.use_fake_data or args.use_reader_op:
             try:
-
                 fetch_ret = exe.run(fetch_list)
             except fluid.core.EOFException as eof:
                 break

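The benchmark changes above simply thread the new flag into the build strategy. Below is a minimal sketch of the same wiring outside the benchmark, assuming the Python-side BuildStrategy.fuse_broadcast_op binding added elsewhere in this PR and the usual ParallelExecutor entry point; the toy network and the GPU place are illustrative assumptions, only the fuse_broadcast_op attribute comes from this diff.

    import paddle.fluid as fluid

    # Toy network so the sketch is self-contained (illustrative only).
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    # Initialize parameters before building the parallel executor.
    fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())

    build_strategy = fluid.BuildStrategy()
    # Same wiring as the benchmark patch: opt in to the fused broadcast op.
    build_strategy.fuse_broadcast_op = True

    pe = fluid.ParallelExecutor(use_cuda=True,
                                loss_name=loss.name,
                                build_strategy=build_strategy)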
paddle/fluid/API.spec (2 additions, 0 deletions)

@@ -355,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind
 paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None))
+paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
 paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))

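The API.spec additions register a LARS momentum optimizer in the public API. A minimal usage sketch following only the ArgSpec above (learning_rate and momentum carry no defaults; lars_coeff and lars_weight_decay default to 0.001 and 0.0005); the surrounding toy network is an illustrative assumption.

    import paddle.fluid as fluid

    # Toy regression network (illustrative only).
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    # Arguments mirror the ArgSpec registered above.
    optimizer = fluid.optimizer.LarsMomentumOptimizer(
        learning_rate=0.001, momentum=0.9,
        lars_coeff=0.001, lars_weight_decay=0.0005)
    optimizer.minimize(loss)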
paddle/fluid/framework/details/CMakeLists.txt (4 additions, 2 deletions)

@@ -16,12 +16,14 @@ if(WITH_GPU)
             dynload_cuda variable_visitor)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
   nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
+  nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)

 else()
   cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
             variable_visitor)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
   cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
+  cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 endif()

 cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)

@@ -34,7 +36,7 @@ if(WITH_GPU)
 endif()

 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-        scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

 if(WITH_GPU)
   cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)

@@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
 cc_library(build_strategy SRCS build_strategy.cc DEPS
         graph_viz_pass multi_devices_graph_pass
         multi_devices_graph_print_pass multi_devices_graph_check_pass
-        fuse_elewise_add_act_pass)
+        fuse_elewise_add_act_pass multi_batch_merge_pass)

paddle/fluid/framework/details/broadcast_op_handle.cc (14 additions, 7 deletions)

@@ -48,20 +48,27 @@ void BroadcastOpHandle::RunImpl() {
     var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
   }

+  BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
+}
+
+void BroadcastOpHandle::BroadcastOneVar(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles,
+    const std::vector<const Scope *> &var_scopes) {
   auto *in_var =
-      var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
   if (UNLIKELY(!in_tensor.IsInitialized())) {
     VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!";
     return;
   }

-  InitOutputValue(*in_var_handle, out_var_handles);
+  InitOutputValue(in_var_handle, out_var_handles);

   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
-      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
+      if (out_var_handle->IsTheSameVar(in_var_handle)) {
         continue;
       }
       auto &out_p = out_var_handle->place_;

@@ -118,12 +125,12 @@ void BroadcastOpHandle::RunImpl() {
       }
     }

-    if (!out_handle->IsTheSameVar(*in_var_handle)) {
-      auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+    if (!out_handle->IsTheSameVar(in_var_handle)) {
+      auto out_var = var_scopes.at(in_var_handle.scope_idx_)
                          ->FindVar(out_var_handles[0]->name_);
       paddle::framework::TensorCopy(
-          in_tensor, in_var_handle->place_,
-          *(dev_ctxes_.at(in_var_handle->place_)),
+          in_tensor, in_var_handle.place_,
+          *(dev_ctxes_.at(in_var_handle.place_)),
           &VariableVisitor::GetMutableTensor(out_var));
     }
   });

paddle/fluid/framework/details/broadcast_op_handle.h (4 additions, 1 deletion)

@@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;

- private:
+  void BroadcastOneVar(const VarHandle &in_var_handle,
+                       const std::vector<VarHandle *> &out_var_handles,
+                       const std::vector<const Scope *> &var_scopes);
+
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
 #ifdef PADDLE_WITH_CUDA

paddle/fluid/framework/details/build_strategy.cc (1 addition, 0 deletions)

@@ -121,6 +121,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(

 USE_PASS(fuse_elewise_add_act_pass);
 USE_PASS(graph_viz_pass);
+USE_PASS(multi_batch_merge_pass);
 USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);

paddle/fluid/framework/details/build_strategy.h (2 additions, 0 deletions)

@@ -69,6 +69,8 @@ struct BuildStrategy {

   bool enable_data_balance_{false};

+  bool fuse_broadcast_op_{false};
+
   // User normally doesn't need to call this API.
   // The PassBuilder allows for more customized insert, remove of passes
   // from python side.
paddle/fluid/framework/details/fused_broadcast_op_handle.cc (new file, 55 additions)

@@ -0,0 +1,55 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+void FusedBroadcastOpHandle::RunImpl() {
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+
+  if (places_.size() == 1UL) return;
+
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
+  WaitInputVarGenerated();
+
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+
+  size_t place_num = places_.size();
+  PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
+
+  for (size_t i = 0; i < in_var_handles.size(); ++i) {
+    BroadcastOneVar(
+        *in_var_handles[i],
+        std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
+                                 out_var_handles.begin() + (i + 1) * place_num),
+        var_scopes);
+  }
+}
+
+std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
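The core of FusedBroadcastOpHandle::RunImpl is its indexing scheme: with P places, the i-th input variable owns the contiguous slice [i*P, (i+1)*P) of the output handles, which is exactly the invariant the PADDLE_ENFORCE_EQ above checks. A small Python sketch of that partitioning, with made-up handle names purely for illustration:

    # Illustrative only: mimic how RunImpl pairs each input handle with its
    # per-place slice of output handles.
    in_var_handles = ['w0', 'w1', 'w2']      # one handle per broadcast variable
    places = ['gpu:0', 'gpu:1']              # P = 2 places
    out_var_handles = [v + '@' + p for v in in_var_handles for p in places]

    place_num = len(places)
    # Same check as the PADDLE_ENFORCE_EQ in RunImpl.
    assert len(in_var_handles) * place_num == len(out_var_handles)

    for i, in_handle in enumerate(in_var_handles):
        outs = out_var_handles[i * place_num:(i + 1) * place_num]
        # Corresponds to BroadcastOneVar(*in_var_handles[i], <slice>, var_scopes).
        print(in_handle, '->', outs)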
paddle/fluid/framework/details/fused_broadcast_op_handle.h (new file, 57 additions)

@@ -0,0 +1,57 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/device_context.h"
+
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct FusedBroadcastOpHandle : public BroadcastOpHandle {
+ public:
+#ifdef PADDLE_WITH_CUDA
+  FusedBroadcastOpHandle(ir::Node *node,
+                         const std::vector<Scope *> local_scopes,
+                         const std::vector<platform::Place> &places,
+                         const platform::NCCLContextMap *nccl_ctx)
+      : BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
+#else
+  FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
+                         const std::vector<platform::Place>& places)
+      : BroadcastOpHandle(node, local_scopes, places) {}
+#endif
+  std::string Name() const override;
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
