Skip to content

Commit 6084af4

Browse files
authored
Fix the bug when a input variable of op is dispensable. (#10268)
* Fix the bug when a input variable of op is dispensable. * Add HasInputs/Outputs interfaces to OperatorBase. * Remove the unreferenced header file.
1 parent 8a0c7e2 commit 6084af4

File tree

4 files changed

+40
-6
lines changed

4 files changed

+40
-6
lines changed

paddle/capi/Matrix.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat,
108108
paddle_error paddle_matrix_get_shape(paddle_matrix mat,
109109
uint64_t* height,
110110
uint64_t* width) {
111-
if (mat == nullptr) return kPD_NULLPTR;
111+
if (mat == nullptr || cast(mat)->mat == nullptr) return kPD_NULLPTR;
112112
if (height != nullptr) {
113113
*height = cast(mat)->mat->getHeight();
114114
}

paddle/fluid/framework/operator.cc

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
9393
RunImpl(scope, place);
9494
}
9595

96+
bool OperatorBase::HasInputs(const std::string& name) const {
97+
if (inputs_.find(name) != inputs_.end()) {
98+
return true;
99+
} else {
100+
return false;
101+
}
102+
}
103+
96104
std::string OperatorBase::Input(const std::string& name) const {
97105
auto& ins = Inputs(name);
98106
PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
109117
return it->second;
110118
}
111119

120+
bool OperatorBase::HasOutputs(const std::string& name) const {
121+
if (outputs_.find(name) != outputs_.end()) {
122+
return true;
123+
} else {
124+
return false;
125+
}
126+
}
127+
112128
std::string OperatorBase::Output(const std::string& name) const {
113129
auto& outs = Outputs(name);
114130
PADDLE_ENFORCE_LE(outs.size(), 1UL,
@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
220236
if (op_info == nullptr || op_info->proto_ == nullptr) return;
221237

222238
for (auto& in : op_info->Proto().inputs()) {
223-
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
224-
"Type %s's input %s is not set", Type(), in.name());
239+
if (!in.dispensable()) {
240+
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
241+
"Operator %s's input, %s, is not set", Type(), in.name());
242+
}
225243
}
226244

227245
for (auto& out : op_info->Proto().outputs()) {
228-
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
229-
"Type %s's output %s is not set", Type(), out.name());
246+
if (!out.dispensable()) {
247+
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
248+
"Operator %s's output, %s, is not set", Type(),
249+
out.name());
250+
}
230251
}
231252
}
232253

@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
332353
: op_(op), scope_(scope) {}
333354

334355
bool HasInput(const std::string& name) const override {
356+
if (!op_.HasInputs(name)) {
357+
return false;
358+
}
335359
auto& ins = Inputs(name);
336360
size_t length = ins.size();
337361
if (length == 0) {
@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
345369
}
346370

347371
bool HasOutput(const std::string& name) const override {
372+
if (!op_.HasOutputs(name)) {
373+
return false;
374+
}
348375
auto& outs = Outputs(name);
349376
size_t length = outs.size();
350377
if (length == 0) {
@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
358385
}
359386

360387
bool HasInputs(const std::string& name) const override {
388+
if (!op_.HasInputs(name)) {
389+
return false;
390+
}
361391
auto inputs = op_.Inputs(name);
362392
if (inputs.empty()) {
363393
return false;
@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
371401
}
372402

373403
bool HasOutputs(const std::string& name) const override {
404+
if (!op_.HasOutputs(name)) {
405+
return false;
406+
}
374407
auto outputs = op_.Outputs(name);
375408
if (outputs.empty()) {
376409
return false;

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ class OperatorBase {
105105
const VariableNameMap& Inputs() const { return inputs_; }
106106
const VariableNameMap& Outputs() const { return outputs_; }
107107

108+
bool HasInputs(const std::string& name) const;
108109
//! Get a input with argument's name described in `op_proto`
109110
std::string Input(const std::string& name) const;
110111
//! Get a input which has multiple variables.
111112
const std::vector<std::string>& Inputs(const std::string& name) const;
112113
//! Get all inputs variable names
113114
std::vector<std::string> InputVars() const;
114115

116+
bool HasOutputs(const std::string& name) const;
115117
//! Get a output with argument's name described in `op_proto`
116118
std::string Output(const std::string& name) const;
117119
//! Get an output which has multiple variables.

paddle/fluid/platform/profiler.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License. */
1818
#include <string>
1919
#include <vector>
2020
#include "paddle/fluid/platform/device_context.h"
21-
#include "paddle/fluid/platform/profiler.pb.h"
2221

2322
namespace paddle {
2423
namespace platform {

0 commit comments

Comments
 (0)