Skip to content

Commit e383ea2

Browse files
committed
fix fetch op handle
1 parent 5f6fd26 commit e383ea2

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() {
4848
WaitInputVarGenerated(platform::CPUPlace());
4949

5050
tensors_.resize(inputs_.size());
51-
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
52-
auto &var_name = var_handle->name_;
5351
platform::CPUPlace cpu;
5452
auto &scopes = *local_scopes_;
5553

56-
for (size_t i = 0; i < scopes.size(); ++i) {
57-
auto &scope = scopes[i];
58-
auto *var =
59-
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
54+
for (size_t i = 0; i < inputs_.size(); ++i) {
55+
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
56+
auto &scope = scopes.at(var_handle->scope_idx_);
57+
auto *var = scope->FindVar(kLocalExecScopeName)
58+
->Get<Scope *>()
59+
->FindVar(var_handle->name_);
6060
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
61-
var_name);
61+
var_handle->name_);
62+
6263
auto &t = var->Get<framework::LoDTensor>();
6364
if (platform::is_gpu_place(t.place())) {
6465
#ifdef PADDLE_WITH_CUDA

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def test_update_dense_parameter(self):
764764

765765
def test_update_sparse_parameter_with_new_strategy(self):
766766
self.check_network_convergence(
767-
is_sparse=False, balance_parameter_opt_between_cards=True)
767+
is_sparse=True, balance_parameter_opt_between_cards=True)
768768

769769
def test_update_dense_parameter_with_new_strategy(self):
770770
self.check_network_convergence(
@@ -836,7 +836,7 @@ def parallel_exe(self, train_inputs, seed):
836836
assert not math.isnan(np.sum(ret[i])) and \
837837
not math.isinf(np.sum(ret[i]))
838838

839-
def test_update_sparse_parameter(self):
839+
def test_fetch_op(self):
840840
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
841841
tst_reader_iter = tst_reader()
842842

0 commit comments

Comments
 (0)