Skip to content

Commit fd286f3

Browse files
authored
Merge pull request #15534 from JiabinYang/fix/multi_output_support_imperative
test=develop, fix/multi_output_support_imperative
2 parents bf7dedc + 5639f49 commit fd286f3

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -555,18 +555,17 @@ Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
555555
template <>
556556
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
557557
const std::string& name) const {
558-
auto names = op().Outputs(name);
558+
auto it = ctx_.outputs.find(name);
559+
if (it == ctx_.outputs.end()) {
560+
return {};
561+
}
562+
const std::vector<Variable*>& vars = it->second;
559563
std::vector<Tensor*> res;
560-
res.reserve(names.size());
561-
std::transform(names.begin(), names.end(), std::back_inserter(res),
562-
[&](const std::string& sub_name) -> Tensor* {
563-
auto var = scope_.FindVar(sub_name);
564-
if (var == nullptr) return nullptr;
565-
PADDLE_ENFORCE(
566-
var->IsType<LoDTensor>(),
567-
"%s should be LoDTensor, but the received type is %s",
568-
sub_name, ToTypeName(var->Type()));
569-
return var->GetMutable<LoDTensor>();
564+
res.reserve(vars.size());
565+
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
566+
[&](Variable* var) -> Tensor* {
567+
return var == nullptr ? nullptr
568+
: var->GetMutable<LoDTensor>();
570569
});
571570
return res;
572571
}

0 commit comments

Comments
 (0)