Skip to content

Commit 28013ef

Browse files
authored
Merge pull request #10019 from reyoung/feature/feed_fetch_tensor_on_cpu
Fix FetchTensor on CPU
2 parents 63cca04 + 6de5fd9 commit 28013ef

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,23 @@ void FetchOpHandle::RunImpl() {
5151
auto *var = static_cast<VarHandle *>(input);
5252
var->generated_op_->Wait(cpu_ctx);
5353
}
54-
5554
tensors_.resize(inputs_.size());
56-
auto *var = static_cast<VarHandle *>(inputs_[0]);
57-
auto &var_name = var->name_;
55+
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
56+
auto &var_name = var_handle->name_;
5857
platform::CPUPlace cpu;
5958
auto &scopes = *local_scopes_;
6059

6160
for (size_t i = 0; i < scopes.size(); ++i) {
6261
auto &scope = scopes[i];
63-
auto &t = scope->FindVar(kLocalExecScopeName)
64-
->Get<Scope *>()
65-
->FindVar(var_name)
66-
->Get<framework::LoDTensor>();
67-
if (platform::is_gpu_place(var->place_)) {
62+
auto *var =
63+
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
64+
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
65+
var_name);
66+
auto &t = var->Get<framework::LoDTensor>();
67+
if (platform::is_gpu_place(t.place())) {
6868
#ifdef PADDLE_WITH_CUDA
6969
TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
70-
dev_ctxes_[t.place()]->Wait();
70+
dev_ctxes_.at(t.place())->Wait();
7171
#endif
7272
} else {
7373
tensors_[i].ShareDataWith(t);

0 commit comments

Comments
 (0)