Skip to content

Commit ff4317c

Browse files
committed
follow comments
1 parent 3606a30 commit ff4317c

File tree

6 files changed

+17
-4
lines changed

6 files changed

+17
-4
lines changed

paddle/fluid/framework/details/build_strategy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct BuildStrategy {
3333
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
3434

3535
std::string debug_graphviz_path_{""};
36+
37+
bool enable_data_balance_{true};
3638
};
3739

3840
} // namespace details

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
7373
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
7474
if (size_device_vec[src_idx][0] <= expected_device_size) {
7575
++src_idx;
76-
PADDLE_ENFORCE_LT(src_idx, device_num - empty_num);
76+
PADDLE_ENFORCE_LT(
77+
src_idx, device_num - empty_num,
78+
"In current srategy an empty tensor should not be copy source.");
7779
}
7880
size_device_vec[src_idx][0] -= expected_device_size;
7981
size_device_vec[dst_idx][0] += expected_device_size;
@@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() {
113115
if (data_idx == 0) {
114116
device_sizes.emplace_back(ins_size);
115117
} else {
116-
PADDLE_ENFORCE_EQ(ins_size, device_sizes.at(place_idx));
118+
PADDLE_ENFORCE_EQ(
119+
ins_size, device_sizes.at(place_idx),
120+
"All data on the same device shall have the same batch size.");
117121
}
118122
}
119123
const auto &balance_plan = GetBalancePlan(device_sizes);

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
216216
} else {
217217
// This op runs on all devices, and its output may have parameter's
218218
// gradients.
219-
if (op->Type() == "read") {
219+
if (op->Type() == "read" && strategy_.enable_data_balance_) {
220220
op->SetAttr("throw_eof_exp", false);
221221
CreateComputationalOps(&result, *op, places_.size());
222222
const auto &data_var_names = op->Output("Out");

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ void OpHandleBase::Run(bool use_cuda) {
5858

5959
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
6060
#ifdef PADDLE_WITH_CUDA
61+
PADDLE_ENFORCE_NOT_NULL(waited_ctx);
6162
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
6263
for (auto &dev_ctx : dev_ctxes_) {
6364
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);

paddle/fluid/pybind/pybind.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle.
643643
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
644644
[](BuildStrategy &self, const std::string &path) {
645645
self.debug_graphviz_path_ = path;
646-
});
646+
})
647+
.def_property(
648+
"enable_data_balance",
649+
[](const BuildStrategy &self) { return self.enable_data_balance_; },
650+
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; });
647651

648652
pe.def(py::init<const std::vector<platform::Place> &,
649653
const std::unordered_set<std::string> &,

python/paddle/fluid/tests/unittests/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ mnist_1.recordio
44
mnist_2.recordio
55
flowers.recordio
66
wmt16.recordio
7+
data_balance_test.recordio
8+
data_balance_with_lod_test.recordio

0 commit comments

Comments
 (0)