
Commit 8b1b756

Merge pull request #10671 from chengduoZH/fix_fetch_op_handle
Refine fetch op handle
2 parents: 7ebb246 + 624caee

2 files changed: +10, -10 lines

2 files changed

+10
-10
lines changed

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 8 additions & 7 deletions
@@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() {
   WaitInputVarGenerated(platform::CPUPlace());

   tensors_.resize(inputs_.size());
-  auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
-  auto &var_name = var_handle->name_;
   platform::CPUPlace cpu;
   auto &scopes = *local_scopes_;

-  for (size_t i = 0; i < scopes.size(); ++i) {
-    auto &scope = scopes[i];
-    auto *var =
-        scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
+    auto &scope = scopes.at(var_handle->scope_idx_);
+    auto *var = scope->FindVar(kLocalExecScopeName)
+                    ->Get<Scope *>()
+                    ->FindVar(var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
-                            var_name);
+                            var_handle->name_);
+
     auto &t = var->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(t.place())) {
 #ifdef PADDLE_WITH_CUDA
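
Note on the change above: before this commit, RunImpl took a single VarHandle from inputs_[0] and looked that one name up in every local scope; after it, the loop walks every input handle and resolves each variable in the scope selected by the handle's own scope_idx_. The snippet below is a minimal standalone sketch of that lookup pattern, not the Paddle code itself; ScopeSketch and VarHandleSketch are hypothetical stand-ins for illustration only.

// Standalone sketch (hypothetical types, not Paddle's): each fetch input
// carries the name of the variable it refers to plus the index of the local
// scope that owns it, so the fetch loop resolves every input in its own scope.
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct ScopeSketch {
  std::unordered_map<std::string, float> vars;  // variable name -> value
  const float *FindVar(const std::string &name) const {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

struct VarHandleSketch {
  std::string name;  // variable this input refers to
  size_t scope_idx;  // local (per-device) scope that owns it
};

int main() {
  std::vector<ScopeSketch> scopes(2);
  scopes[0].vars["loss"] = 0.5f;  // value produced on device 0
  scopes[1].vars["loss"] = 0.7f;  // value produced on device 1

  std::vector<VarHandleSketch> inputs = {{"loss", 0}, {"loss", 1}};

  // Post-commit pattern: iterate over the input handles and pick the scope
  // through each handle's own scope_idx, instead of scanning every scope
  // for a single name taken from inputs[0].
  std::vector<float> fetched(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    const ScopeSketch &scope = scopes.at(inputs[i].scope_idx);
    const float *var = scope.FindVar(inputs[i].name);
    assert(var != nullptr && "Cannot find variable in execution scope");
    fetched[i] = *var;
  }
  return 0;
}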

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 2 additions & 3 deletions
@@ -775,7 +775,7 @@ def test_update_sparse_parameter_reduce(self):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         self.check_network_convergence(
-            is_sparse=False, build_strategy=build_strategy)
+            is_sparse=True, build_strategy=build_strategy)

     def test_update_dense_parameter_reduce(self):
         build_strategy = fluid.BuildStrategy()
@@ -849,8 +849,7 @@ def parallel_exe(self, train_inputs, seed):
             assert not math.isnan(np.sum(ret[i])) and \
                    not math.isinf(np.sum(ret[i]))

-    @unittest.skip("this test is buggy")
-    def test_feed(self):
+    def test_fetch_op(self):
         tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
         tst_reader_iter = tst_reader()
