@@ -555,18 +555,17 @@ Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
555
555
template <>
556
556
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
557
557
const std::string& name) const {
558
- auto names = op ().Outputs (name);
558
+ auto it = ctx_.outputs .find (name);
559
+ if (it == ctx_.outputs .end ()) {
560
+ return {};
561
+ }
562
+ const std::vector<Variable*>& vars = it->second ;
559
563
std::vector<Tensor*> res;
560
- res.reserve (names.size ());
561
- std::transform (names.begin (), names.end (), std::back_inserter (res),
562
- [&](const std::string& sub_name) -> Tensor* {
563
- auto var = scope_.FindVar (sub_name);
564
- if (var == nullptr ) return nullptr ;
565
- PADDLE_ENFORCE (
566
- var->IsType <LoDTensor>(),
567
- " %s should be LoDTensor, but the received type is %s" ,
568
- sub_name, ToTypeName (var->Type ()));
569
- return var->GetMutable <LoDTensor>();
564
+ res.reserve (vars.size ());
565
+ std::transform (vars.begin (), vars.end (), std::back_inserter (res),
566
+ [&](Variable* var) -> Tensor* {
567
+ return var == nullptr ? nullptr
568
+ : var->GetMutable <LoDTensor>();
570
569
});
571
570
return res;
572
571
}
0 commit comments