Skip to content

Commit 01bbe53

Browse files
author
chengduo
authored
Merge pull request #11079 from chengduoZH/balance_parameter_update
Balance parameter opt
2 parents 5987057 + e330cd0 commit 01bbe53

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
14+
#include <algorithm>
1515
#include <fstream>
16+
#include <string>
1617
#include <utility>
18+
#include <vector>
19+
1720
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
1821
#include "paddle/fluid/framework/details/computation_op_handle.h"
22+
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
1923
#include "paddle/fluid/framework/details/reduce_op_handle.h"
2024
#include "paddle/fluid/framework/details/rpc_op_handle.h"
2125
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
@@ -26,9 +30,6 @@
2630
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
2731
#endif
2832

29-
#include <string>
30-
#include <vector>
31-
3233
// Defines the string flag `ssa_graph_path`; its default value is
// "/tmp/ssa_graph.dot".
// NOTE(review): the help text below states the default as "/tmp/graph.dot",
// which does not match the actual default above. The two adjacent literals
// are also concatenated with no space after the comma
// ("...GLOG_v=10,default ...").
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
3334
"the ssa graph path only print with GLOG_v=10,"
3435
"default /tmp/graph.dot");
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
148149

149150
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
150151
const ProgramDesc &program) const {
151-
std::unordered_map<std::string, proto::VarType::Type> var_types;
152+
std::unordered_map<std::string, VarDesc *> all_vars;
152153
for (auto *var : program.Block(0).AllVars()) {
153-
var_types[var->Name()] = var->GetType();
154+
all_vars[var->Name()] = var;
154155
}
155156

156157
auto graph = new SSAGraph();
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
167168
auto send_vars = FindDistTrainSendVars(program);
168169
auto recv_vars = FindDistTrainRecvVars(program);
169170

170-
size_t cur_device_id = 0;
171171
std::vector<std::unordered_set<std::string>> var_name_on_devices;
172172
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
173173
var_name_on_devices.resize(places_.size());
174174
bcast_var_name_set.resize(places_.size());
175175

176+
size_t cur_device_id = 0;
177+
std::vector<int64_t> balance_grads(places_.size(), 0);
178+
179+
// Chooses the device for the gradient variable named `g_name`.
//
// `balance_grads` holds one running total per place (it is sized by
// places_.size() and starts at zero). The lambda picks the index whose total
// is currently the smallest, adds this gradient's element count to that
// total, and returns the index.
//
// NOTE(review): `g_name` is taken by non-const reference although it is only
// read here; a `const std::string &` would accept the same call site.
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
180+
// unordered_map::at throws std::out_of_range when g_name is not a key of
// all_vars, so an unknown name fails here rather than at the null check.
auto var_desc = all_vars.at(g_name);
181+
PADDLE_ENFORCE_NOT_NULL(var_desc);
182+
// Element count = product of the dimensions reported by GetShape().
auto dim = framework::make_ddim(var_desc->GetShape());
183+
int64_t numel = framework::product(dim);
184+
// Presumably guards against a shape containing a negative (unknown)
// dimension, which would make the product negative -- TODO confirm.
PADDLE_ENFORCE_GE(numel, 0);
185+
// std::min_element returns the first smallest element, so ties go to the
// lowest device index.
auto smallest =
186+
std::min_element(std::begin(balance_grads), std::end(balance_grads));
187+
size_t dev_id =
188+
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
189+
// Record the load now so the next call sees this gradient's contribution.
balance_grads[dev_id] += numel;
190+
return dev_id;
191+
};
192+
176193
bool is_forwarding = true;
177194
for (auto *op : program.Block(0).AllOps()) {
178195
if (boost::get<int>(
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
220237

221238
switch (strategy_.reduce_) {
222239
case BuildStrategy::ReduceStrategy::kReduce:
240+
cur_device_id = get_appropriate_dev(g_name);
223241
CreateReduceOp(&result, g_name, cur_device_id);
224242
var_name_on_devices[cur_device_id].emplace(g_name);
225243
bcast_var_name_set[cur_device_id].emplace(p_name);
226-
cur_device_id = (cur_device_id + 1) % places_.size();
227244
break;
228245
case BuildStrategy::ReduceStrategy::kAllReduce:
229-
if (IsSparseGradient(var_types, g_name)) {
246+
if (IsSparseGradient(all_vars, g_name)) {
230247
CreateReduceOp(&result, g_name, 0);
231248
CreateBroadcastOp(&result, g_name, 0);
232249
} else {
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
269286
}
270287

271288
bool MultiDevSSAGraphBuilder::IsSparseGradient(
272-
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
289+
const std::unordered_map<std::string, VarDesc *> &all_vars,
273290
const std::string &og) const {
274-
PADDLE_ENFORCE(var_types.count(og) != 0);
275-
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
291+
PADDLE_ENFORCE(all_vars.count(og) != 0);
292+
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
276293
return true;
277294
}
278295
return false;

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
106106
size_t src_dev_id) const;
107107

108108
bool IsSparseGradient(
109-
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
109+
const std::unordered_map<std::string, VarDesc *> &all_vars,
110110
const std::string &og) const;
111111

112112
private:

0 commit comments

Comments
 (0)