Skip to content

Commit d087866

Browse files
authored
Merge pull request #12005 from reyoung/feature/fix_data_balance_on_one_gpu
Fix data balance on single GPU
2 parents 89704d9 + 8e86721 commit d087866

File tree

5 files changed

+26
-10
lines changed

5 files changed

+26
-10
lines changed

paddle/fluid/framework/details/build_strategy.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ struct BuildStrategy {
3434

3535
std::string debug_graphviz_path_{""};
3636

37-
bool enable_data_balance_{true};
37+
bool enable_data_balance_{false};
3838
};
3939

4040
} // namespace details

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -86,9 +86,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
8686
}
8787

8888
void DataBalanceOpHandle::RunImpl() {
89-
if (places_.size() == 1) {
90-
return;
91-
}
89+
PADDLE_ENFORCE_GT(places_.size(), 1,
90+
"Data balance can only be enabled when the number of "
91+
"places to run larger than 1.");
9292
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
9393
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
9494
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5959
grad_names_.insert(GradVarName(p));
6060
}
6161
balance_vars_.resize(places_.size(), 0);
62+
if (strategy_.enable_data_balance_ && places_.size() == 1) {
63+
LOG(WARNING) << "It is no need to enable data balance when there is only "
64+
"one place. enable_data_balance is set to False.";
65+
strategy_.enable_data_balance_ = false;
66+
}
6267
}
6368

6469
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,

paddle/fluid/operators/read_op.cc

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
9292
void Make() override {
9393
AddInput("Reader", "(ReaderHolder) The executed reader.");
9494
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
95-
AddAttr<bool>("throw_eof_exp",
96-
"If set true, an exception will be thrown when the Reader "
97-
"yields empty (which means there is no next data).")
95+
AddAttr<bool>(
96+
"throw_eof_exp",
97+
"If set true, an exception will be thrown when the Reader "
98+
"yields empty (which means there is no next data).\n"
99+
"NOTES: This flag must be true always. It will be set to false"
100+
" only when the data-balance is enabled in ParallelExecutor"
101+
" and it is set by ParallelExecutor instance, not users.")
98102
.SetDefault(true);
99103
AddComment(R"DOC(
100104
Read Operator

python/paddle/fluid/tests/unittests/test_data_balance.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,12 @@ def main(self):
103103
exe = fluid.Executor(place)
104104
exe.run(startup_prog)
105105

106+
build_strategy = fluid.BuildStrategy()
107+
build_strategy.enable_data_balance = True
106108
parallel_exe = fluid.ParallelExecutor(
107-
use_cuda=self.use_cuda, main_program=main_prog)
109+
use_cuda=self.use_cuda,
110+
main_program=main_prog,
111+
build_strategy=build_strategy)
108112

109113
if (parallel_exe.device_count > self.batch_size):
110114
print("WARNING: Unittest TestDataBalance skipped. \
@@ -145,9 +149,12 @@ def main_lod(self):
145149
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
146150
exe = fluid.Executor(place)
147151
exe.run(startup_prog)
148-
152+
build_strategy = fluid.BuildStrategy()
153+
build_strategy.enable_data_balance = True
149154
parallel_exe = fluid.ParallelExecutor(
150-
use_cuda=self.use_cuda, main_program=main_prog)
155+
use_cuda=self.use_cuda,
156+
main_program=main_prog,
157+
build_strategy=build_strategy)
151158

152159
if (parallel_exe.device_count > self.batch_size):
153160
print("WARNING: Unittest TestDataBalance skipped. \

0 commit comments

Comments
 (0)