@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
93
93
RunImpl (scope, place);
94
94
}
95
95
96
+ bool OperatorBase::HasInputs (const std::string& name) const {
97
+ if (inputs_.find (name) != inputs_.end ()) {
98
+ return true ;
99
+ } else {
100
+ return false ;
101
+ }
102
+ }
103
+
96
104
std::string OperatorBase::Input (const std::string& name) const {
97
105
auto & ins = Inputs (name);
98
106
PADDLE_ENFORCE_LE (ins.size (), 1UL ,
@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
109
117
return it->second ;
110
118
}
111
119
120
+ bool OperatorBase::HasOutputs (const std::string& name) const {
121
+ if (outputs_.find (name) != outputs_.end ()) {
122
+ return true ;
123
+ } else {
124
+ return false ;
125
+ }
126
+ }
127
+
112
128
std::string OperatorBase::Output (const std::string& name) const {
113
129
auto & outs = Outputs (name);
114
130
PADDLE_ENFORCE_LE (outs.size (), 1UL ,
@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
220
236
if (op_info == nullptr || op_info->proto_ == nullptr ) return ;
221
237
222
238
for (auto & in : op_info->Proto ().inputs ()) {
223
- PADDLE_ENFORCE (inputs_.find (in.name ()) != inputs_.end (),
224
- " Type %s's input %s is not set" , Type (), in.name ());
239
+ if (!in.dispensable ()) {
240
+ PADDLE_ENFORCE (inputs_.find (in.name ()) != inputs_.end (),
241
+ " Operator %s's input, %s, is not set" , Type (), in.name ());
242
+ }
225
243
}
226
244
227
245
for (auto & out : op_info->Proto ().outputs ()) {
228
- PADDLE_ENFORCE (outputs_.find (out.name ()) != outputs_.end (),
229
- " Type %s's output %s is not set" , Type (), out.name ());
246
+ if (!out.dispensable ()) {
247
+ PADDLE_ENFORCE (outputs_.find (out.name ()) != outputs_.end (),
248
+ " Operator %s's output, %s, is not set" , Type (),
249
+ out.name ());
250
+ }
230
251
}
231
252
}
232
253
@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
332
353
: op_(op), scope_(scope) {}
333
354
334
355
bool HasInput (const std::string& name) const override {
356
+ if (!op_.HasInputs (name)) {
357
+ return false ;
358
+ }
335
359
auto & ins = Inputs (name);
336
360
size_t length = ins.size ();
337
361
if (length == 0 ) {
@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
345
369
}
346
370
347
371
bool HasOutput (const std::string& name) const override {
372
+ if (!op_.HasOutputs (name)) {
373
+ return false ;
374
+ }
348
375
auto & outs = Outputs (name);
349
376
size_t length = outs.size ();
350
377
if (length == 0 ) {
@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
358
385
}
359
386
360
387
bool HasInputs (const std::string& name) const override {
388
+ if (!op_.HasInputs (name)) {
389
+ return false ;
390
+ }
361
391
auto inputs = op_.Inputs (name);
362
392
if (inputs.empty ()) {
363
393
return false ;
@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
371
401
}
372
402
373
403
bool HasOutputs (const std::string& name) const override {
404
+ if (!op_.HasOutputs (name)) {
405
+ return false ;
406
+ }
374
407
auto outputs = op_.Outputs (name);
375
408
if (outputs.empty ()) {
376
409
return false ;
0 commit comments