Skip to content

Commit 3232618

Browse files
authored
cherry-pick Make fuse_all_reduce_op_pass support mix_precision test=develop test=release (#18490)
1 parent 2410700 commit 3232618

File tree

8 files changed

+332
-152
lines changed

8 files changed

+332
-152
lines changed

paddle/fluid/framework/details/multi_devices_helper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
5858
typedef std::string FusedOptType;
5959
constexpr char kFusedOptType[] = "fused_opt_type";
6060

61-
typedef std::string FusedGrads;
61+
typedef std::vector<std::string> FusedGrads;
6262
constexpr char kFusedGrads[] = "fused_gradients";
6363

6464
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
6565
constexpr char kParamsAndGrads[] = "params_grads";
6666

6767
typedef std::vector<std::vector<std::pair<std::string, std::string>>>
68-
GroupGradsAndParams;
69-
constexpr char kGroupGradsAndParams[] = "group_grads_params";
68+
GroupParamsAndGrads;
69+
constexpr char kGroupParamsAndGrads[] = "group_params_grads";
7070

7171
} // namespace details
7272
} // namespace framework

paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc

Lines changed: 170 additions & 113 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,17 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
101101
"this pass.");
102102
}
103103
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
104+
PADDLE_ENFORCE_NE(fused_grad.size(), 0,
105+
"The fused gradient should not be empty.");
106+
PADDLE_ENFORCE_EQ(fused_grad.size(), 1,
107+
"Because the dtype of those gradients "
108+
"is not unified, so the number of fused gradients is "
109+
"more than one, but it is not supported currently.");
104110
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
105-
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
111+
auto iter =
112+
std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front());
106113
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
107-
fused_vars_name[kGrad] = fused_grad;
114+
fused_vars_name[kGrad] = fused_grad.front();
108115

109116
// Sort the parameters and auxiliary variables according
110117
// to parameters' name to make variables' name correspond correctly.

paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,46 +30,24 @@ class FuseAllReduceOpPass : public ir::Pass {
3030
protected:
3131
void ApplyImpl(ir::Graph *graph) const override {
3232
ir::Graph &result = *graph;
33-
3433
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
3534
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
3635
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
3736
auto *multi_nccl_ctxs =
3837
&Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
3938
#endif
4039

41-
std::unordered_set<std::string> grads;
4240
auto &params_grads =
4341
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
4442
size_t num_of_all_reduce = params_grads.size();
43+
std::unordered_set<std::string> grads;
4544
grads.reserve(num_of_all_reduce);
4645
for (auto p_g : params_grads) {
4746
grads.insert(p_g.second);
4847
}
4948

50-
size_t num_place = places.size();
51-
std::unordered_map<std::string, ir::Node *> all_reduce_ops;
52-
all_reduce_ops.reserve(grads.size());
53-
for (auto &node : result.Nodes()) {
54-
if (node->IsOp()) {
55-
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
56-
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
57-
&node->Wrapper<details::OpHandleBase>());
58-
if (all_reduce_op_handle) {
59-
auto inputs = details::DynamicCast<details::VarHandle>(
60-
all_reduce_op_handle->Inputs());
61-
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
62-
// The inputs' name should be the same.
63-
auto &grad_name = inputs[0]->name();
64-
for (size_t i = 1; i < inputs.size(); ++i) {
65-
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
66-
"The input name should be the same.");
67-
}
68-
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
69-
all_reduce_ops.emplace(grad_name, node);
70-
}
71-
}
72-
}
49+
std::unordered_map<std::string, Node *> all_reduce_ops =
50+
GetAllReduceOps(result, places, grads);
7351

7452
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
7553
if (all_reduce_ops.size() == 0) {
@@ -82,16 +60,16 @@ class FuseAllReduceOpPass : public ir::Pass {
8260
"it is not supported currently.");
8361
VLOG(10) << "Insert fused_all_reduce";
8462

85-
auto &group_grads_params =
86-
graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
63+
auto &group_params_grads =
64+
graph->Get<details::GroupParamsAndGrads>(details::kGroupParamsAndGrads);
8765

88-
for (auto &group_g_p : group_grads_params) {
89-
size_t group_size = group_g_p.size();
66+
for (auto &group_p_g : group_params_grads) {
67+
size_t group_size = group_p_g.size();
9068
PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
9169
std::vector<ir::Node *> group_all_reduce_ops;
9270
group_all_reduce_ops.reserve(group_size);
93-
for (auto &g_p : group_g_p) {
94-
group_all_reduce_ops.emplace_back(all_reduce_ops.at(g_p.first));
71+
for (auto &p_g : group_p_g) {
72+
group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second));
9573
}
9674
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
9775
InsertFusedAllReduce(places, local_scopes, group_size,
@@ -103,6 +81,35 @@ class FuseAllReduceOpPass : public ir::Pass {
10381
}
10482
}
10583

84+
std::unordered_map<std::string, Node *> GetAllReduceOps(
85+
const Graph &result, const std::vector<platform::Place> &places,
86+
const std::unordered_set<std::string> &grads) const {
87+
size_t num_place = places.size();
88+
std::unordered_map<std::string, Node *> all_reduce_ops;
89+
all_reduce_ops.reserve(grads.size());
90+
for (auto &node : result.Nodes()) {
91+
if (node->IsOp()) {
92+
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
93+
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
94+
&node->Wrapper<details::OpHandleBase>());
95+
if (all_reduce_op_handle) {
96+
auto inputs = details::DynamicCast<details::VarHandle>(
97+
all_reduce_op_handle->Inputs());
98+
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
99+
// The inputs' name should be the same.
100+
auto &grad_name = inputs[0]->name();
101+
for (size_t i = 1; i < inputs.size(); ++i) {
102+
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
103+
"The input name should be the same.");
104+
}
105+
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
106+
all_reduce_ops.emplace(grad_name, node);
107+
}
108+
}
109+
}
110+
return all_reduce_ops;
111+
}
112+
106113
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
107114
const std::vector<Scope *> &local_scopes,
108115
const size_t num_of_all_reduce,

paddle/fluid/operators/alloc_continuous_space_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,11 @@ REGISTER_OPERATOR(alloc_continuous_space,
227227
paddle::operators::AllocContinuousSpaceOp,
228228
paddle::operators::AllocContinuousSpaceOpMaker);
229229
namespace ops = paddle::operators;
230+
namespace plat = paddle::platform;
230231
REGISTER_OP_CPU_KERNEL(
231232
alloc_continuous_space,
233+
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
234+
plat::float16>,
232235
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, int>,
233236
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, float>,
234237
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
@@ -237,6 +240,8 @@ REGISTER_OP_CPU_KERNEL(
237240
#ifdef PADDLE_WITH_CUDA
238241
REGISTER_OP_CUDA_KERNEL(
239242
alloc_continuous_space,
243+
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,
244+
plat::float16>,
240245
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, int>,
241246
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, float>,
242247
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,

paddle/fluid/operators/optimizers/sgd_op.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/optimizers/sgd_op.h"
16-
16+
#include <string>
1717
namespace paddle {
1818
namespace operators {
1919

@@ -46,6 +46,17 @@ class SGDOp : public framework::OperatorWithKernel {
4646
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
4747
return framework::OpKernelType(data_type, ctx.device_context());
4848
}
49+
50+
framework::OpKernelType GetKernelTypeForVar(
51+
const std::string &var_name, const framework::Tensor &tensor,
52+
const framework::OpKernelType &expected_kernel_type) const {
53+
if (var_name == "LearningRate") {
54+
return framework::OpKernelType(tensor.type(), tensor.place(),
55+
tensor.layout());
56+
}
57+
return framework::OpKernelType(expected_kernel_type.data_type_,
58+
tensor.place(), tensor.layout());
59+
}
4960
};
5061

5162
class SGDOpInferVarType : public framework::VarTypeInference {

paddle/fluid/operators/optimizers/sgd_op.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
4646
// Atomic Operation to avoid concurrent write error.
4747
paddle::platform::CudaAtomicAdd(
4848
tensor_out_ptr + index,
49-
-1.0 * learning_rate[0] * selected_rows_ptr[index]);
49+
-static_cast<T>(1.0) * learning_rate[0] * selected_rows_ptr[index]);
5050
}
5151
}
5252
}
@@ -122,5 +122,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
122122
} // namespace paddle
123123

124124
namespace ops = paddle::operators;
125+
namespace plat = paddle::platform;
125126
REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
126-
ops::SGDOpCUDAKernel<double>);
127+
ops::SGDOpCUDAKernel<double>,
128+
ops::SGDOpCUDAKernel<plat::float16>);
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import paddle.fluid.core as core
18+
import math
19+
import os
20+
import sys
21+
import unittest
22+
23+
import numpy as np
24+
import paddle
25+
import paddle.fluid as fluid
26+
from simple_nets import init_data
27+
from parallel_executor_test_base import TestParallelExecutorBase
28+
29+
batch_size = 12
30+
img_shape = [1, 28, 28]
31+
32+
33+
def loss_net(hidden, label):
34+
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
35+
loss = fluid.layers.cross_entropy(input=prediction, label=label)
36+
avg_loss = fluid.layers.mean(loss)
37+
return avg_loss
38+
39+
40+
def conv_net(use_feed):
41+
img = fluid.layers.data(name='image', shape=img_shape, dtype='float16')
42+
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
43+
44+
conv_pool_1 = fluid.nets.simple_img_conv_pool(
45+
input=img,
46+
filter_size=5,
47+
num_filters=20,
48+
pool_size=2,
49+
pool_stride=2,
50+
act="relu")
51+
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
52+
53+
conv_pool_1 = fluid.layers.cast(conv_pool_1, np.float32)
54+
conv_pool_2 = fluid.nets.simple_img_conv_pool(
55+
input=conv_pool_1,
56+
filter_size=5,
57+
num_filters=50,
58+
pool_size=2,
59+
pool_stride=2,
60+
act="relu")
61+
hidden = fluid.layers.cast(conv_pool_2, np.float32)
62+
return loss_net(hidden, label)
63+
64+
65+
def _optimizer(learning_rate=1e-6):
66+
optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
67+
return optimizer
68+
69+
70+
class TestResnet(TestParallelExecutorBase):
71+
def check_model(self, use_cuda):
72+
img, label = init_data(
73+
batch_size=batch_size, img_shape=img_shape, label_range=9)
74+
img = np.float16(img).view(np.uint16)
75+
feed_dict = {"image": img, "label": label}
76+
77+
TestParallelExecutorBase.check_network_convergence(
78+
conv_net,
79+
feed_dict=feed_dict,
80+
iter=10,
81+
use_cuda=use_cuda,
82+
fuse_all_reduce_ops=True,
83+
optimizer=_optimizer)
84+
85+
def test_model(self):
86+
if core.is_compiled_with_cuda():
87+
self.check_model(True)
88+
89+
90+
if __name__ == '__main__':
91+
unittest.main()

0 commit comments

Comments
 (0)