Skip to content

Commit ab6600e

Browse files
Author: chengduo (authored)
Fix bug of FastThreadedExecutor (#16666)
test=release/1.4
1 parent 4cc6144 commit ab6600e

File tree

4 files changed

+44
-13
lines changed

4 files changed

+44
-13
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
5656
fetches.resize(fetch_tensors.size());
5757
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
5858
std::vector<FetchOpHandle *> fetch_ops;
59+
std::vector<OpHandleBase *> ready_fetch_ops;
5960

6061
for (auto &fetch_var_name : fetch_tensors) {
6162
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
7071
auto &var_name = fetch_tensors[i];
7172
auto fetched_var_it = fetched_vars.find(var_name);
7273
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
73-
"Cannot find fetched variable.(Perhaps the main_program "
74-
"is not set to ParallelExecutor)");
74+
"Cannot find fetched variable(%s).(Perhaps the main_program "
75+
"is not set to ParallelExecutor)",
76+
var_name);
7577

7678
auto &vars = fetched_var_it->second;
7779

@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
8890
op->AddInput(var);
8991
}
9092

91-
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
93+
int dep = static_cast<int>(op->NotReadyInputSize());
94+
(*op_deps)[op] = dep;
95+
if (dep == 0) {
96+
ready_fetch_ops.emplace_back(op);
97+
}
9298
}
9399

94100
size_t num_complete = 0;
@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
97103
for (auto op : bootstrap_ops_) {
98104
RunOpAsync(op_deps.get(), op, complete_q);
99105
}
100-
106+
for (auto op : ready_fetch_ops) {
107+
RunOpAsync(op_deps.get(), op, complete_q);
108+
}
101109
while (num_complete != op_deps->size()) {
102110
size_t num_comp = complete_q->Pop();
103111
if (num_comp == -1UL) {

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/details/fetch_op_handle.h"
16-
1716
#include <string>
1817
#include <vector>
18+
#include "paddle/fluid/platform/profiler.h"
1919

2020
namespace paddle {
2121
namespace framework {
@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
4444
}
4545

4646
void FetchOpHandle::RunImpl() {
47+
platform::RecordEvent record_event(Name());
4748
WaitInputVarGenerated(platform::CPUPlace());
4849

4950
tensors_.resize(inputs_.size());
@@ -62,7 +63,8 @@ void FetchOpHandle::RunImpl() {
6263
auto &t = var->Get<framework::LoDTensor>();
6364
if (platform::is_gpu_place(t.place())) {
6465
#ifdef PADDLE_WITH_CUDA
65-
TensorCopySync(t, cpu, &tensors_[i]);
66+
TensorCopy(t, cpu, *dev_ctxes_.at(t.place()), &tensors_[i]);
67+
dev_ctxes_.at(t.place())->Wait();
6668
#endif
6769
} else {
6870
tensors_[i].ShareDataWith(t);

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
6868
}
6969
set.clear();
7070
};
71-
auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
71+
7272
// Clean run context
7373
run_op_futures_.clear();
7474
exception_holder_.Clear();
@@ -102,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
102102
auto &deps = pending_ops[op];
103103
--deps;
104104
if (deps == 0) {
105-
run_all_op(op);
105+
ready_ops.insert(op);
106106
}
107107
}
108108
}

python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,15 @@ def Lenet(data, class_dim):
3838

3939

4040
class TestFetchAndFeed(unittest.TestCase):
41-
def parallel_exe(self, use_cuda, run_parallel_exe, seed=1):
41+
@classmethod
42+
def setUpClass(cls):
43+
os.environ['CPU_NUM'] = str(4)
44+
45+
def parallel_exe(self,
46+
use_cuda,
47+
run_parallel_exe,
48+
use_experimental_executor=False,
49+
seed=1):
4250
main_program = fluid.Program()
4351
startup = fluid.Program()
4452
startup.random_seed = seed
@@ -63,8 +71,12 @@ def parallel_exe(self, use_cuda, run_parallel_exe, seed=1):
6371
build_strategy = fluid.BuildStrategy()
6472
build_strategy.enable_inplace = False
6573
build_strategy.memory_optimize = False
74+
exec_strategy = fluid.ExecutionStrategy()
75+
exec_strategy.use_experimental_executor = use_experimental_executor
6676
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
67-
loss_name=loss.name, build_strategy=build_strategy)
77+
loss_name=loss.name,
78+
build_strategy=build_strategy,
79+
exec_strategy=exec_strategy)
6880

6981
run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
7082

@@ -131,17 +143,26 @@ def get_data(batch_size=8):
131143
if batch_id == 2:
132144
break
133145

134-
def test_fetch(self):
135-
os.environ['CPU_NUM'] = str(4)
146+
def test_fetch_with_threaded_executor(self):
136147
if core.is_compiled_with_cuda():
137148
self.parallel_exe(
138149
use_cuda=True,
139150
run_parallel_exe=self.run_parallel_exe_with_fetch)
140151
self.parallel_exe(
141152
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
142153

154+
def test_fetch_with_fast_threaded_executor(self):
155+
if core.is_compiled_with_cuda():
156+
self.parallel_exe(
157+
use_cuda=True,
158+
run_parallel_exe=self.run_parallel_exe_with_fetch,
159+
use_experimental_executor=True)
160+
self.parallel_exe(
161+
use_cuda=False,
162+
run_parallel_exe=self.run_parallel_exe_with_fetch,
163+
use_experimental_executor=True)
164+
143165
def test_feed(self):
144-
os.environ['CPU_NUM'] = str(4)
145166
if core.is_compiled_with_cuda():
146167
self.parallel_exe(
147168
use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed)

0 commit comments

Comments (0)