Skip to content

Commit 5803115

Browse files
authored
Merge pull request #11288 from JiayiFeng/fix_bug_of_ExecutionContext
fix bugs in the implementation of 'HasInput' and 'HasOutput'
2 parents 635099c + dc8e0b4 commit 5803115

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,38 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
293293
}
294294
}
295295

296+
bool ExecutionContext::HasInput(const std::string& name) const {
297+
if (!op_.HasInputs(name)) {
298+
return false;
299+
}
300+
auto& ins = Inputs(name);
301+
size_t length = ins.size();
302+
if (length == 0) {
303+
return false;
304+
}
305+
PADDLE_ENFORCE_EQ(length, 1UL,
306+
"Input %s should not have more than one inputs", name);
307+
auto arg = ins[0];
308+
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
309+
return var != nullptr;
310+
}
311+
312+
bool ExecutionContext::HasOutput(const std::string& name) const {
313+
if (!op_.HasOutputs(name)) {
314+
return false;
315+
}
316+
auto& outs = Outputs(name);
317+
size_t length = outs.size();
318+
if (length == 0) {
319+
return false;
320+
}
321+
PADDLE_ENFORCE_EQ(length, 1UL,
322+
"Output %s should not have more than one inputs", name);
323+
auto arg = outs[0];
324+
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
325+
return var != nullptr;
326+
}
327+
296328
template <>
297329
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
298330
auto* var = InputVar(name);

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ class ExecutionContext {
191191
return op_.Attr<T>(name);
192192
}
193193

194-
bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
194+
bool HasInput(const std::string& name) const;
195195

196-
bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
196+
bool HasOutput(const std::string& name) const;
197197

198198
size_t InputSize(const std::string& name) const {
199199
return op_.Inputs(name).size();

0 commit comments

Comments
 (0)