Skip to content

Commit 8bb824b

Browse files
committed
refine infershape hasinput and hasoutput
1 parent c4394bc commit 8bb824b

File tree

6 files changed

+197
-311
lines changed

6 files changed

+197
-311
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 155 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/lod_tensor.h"
2222
#include "paddle/fluid/framework/operator.h"
2323
#include "paddle/fluid/framework/shape_inference.h"
24-
#include "paddle/fluid/framework/shape_runtime_infer.h"
2524
#include "paddle/fluid/framework/var_type.h"
2625
#include "paddle/fluid/platform/profiler.h"
2726

@@ -459,147 +458,184 @@ bool OpSupportGPU(const std::string& op_type) {
459458
return false;
460459
}
461460

462-
bool RuntimeInferShapeContext::HasInput(const std::string& name) const {
463-
if (!op_.HasInputs(name)) {
464-
return false;
465-
}
466-
auto& ins = Inputs(name);
467-
size_t length = ins.size();
468-
if (length == 0) {
469-
return false;
470-
}
471-
PADDLE_ENFORCE_EQ(length, 1UL,
472-
"Input %s should not have more than one inputs", name);
473-
auto ipt = ins[0];
474-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
475-
return var != nullptr;
476-
}
461+
class RuntimeInferShapeContext : public InferShapeContext {
462+
public:
463+
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
464+
: op_(op), scope_(scope) {}
477465

478-
bool RuntimeInferShapeContext::HasOutput(const std::string& name) const {
479-
if (!op_.HasOutputs(name)) {
480-
return false;
481-
}
482-
auto& outs = Outputs(name);
483-
size_t length = outs.size();
484-
if (length == 0) {
485-
return false;
486-
}
487-
PADDLE_ENFORCE_EQ(length, 1UL,
488-
"Output %s should not have more than one inputs", name);
489-
auto ipt = outs[0];
490-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
491-
return var != nullptr;
492-
}
466+
bool HasInput(const std::string& name) const override {
467+
// has only one input
468+
const auto& ins = op_.Inputs();
469+
auto it = ins.find(name);
470+
if (it == ins.end()) {
471+
return false;
472+
}
473+
const auto& in = it->second;
493474

494-
bool RuntimeInferShapeContext::HasInputs(const std::string& name) const {
495-
if (!op_.HasInputs(name)) {
496-
return false;
497-
}
498-
auto inputs = op_.Inputs(name);
499-
if (inputs.empty()) {
500-
return false;
501-
}
502-
for (auto& input : inputs) {
503-
if (scope_.FindVar(input) == nullptr) {
475+
if (in.size() != 1 || in[0] == kEmptyVarName) {
504476
return false;
505477
}
478+
return scope_.FindVar(in[0]) != nullptr;
506479
}
507-
return true;
508-
}
509480

510-
bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const {
511-
if (!op_.HasOutputs(name)) {
512-
return false;
481+
bool HasOutput(const std::string& name) const override {
482+
// has only one output
483+
const auto& outs = op_.Outputs();
484+
auto it = outs.find(name);
485+
if (it == outs.end()) {
486+
return false;
487+
}
488+
const auto& out = it->second;
489+
if (out.size() != 1 || out[0] == kEmptyVarName) {
490+
return false;
491+
}
492+
return scope_.FindVar(out[0]) != nullptr;
513493
}
514-
auto outputs = op_.Outputs(name);
515-
if (outputs.empty()) {
516-
return false;
494+
495+
bool HasInputs(const std::string& name) const override {
496+
if (!op_.HasInputs(name)) {
497+
return false;
498+
}
499+
auto inputs = op_.Inputs(name);
500+
if (inputs.empty()) {
501+
return false;
502+
}
503+
for (auto& input : inputs) {
504+
if (scope_.FindVar(input) == nullptr) {
505+
return false;
506+
}
507+
}
508+
return true;
517509
}
518-
for (auto& output : outputs) {
519-
if (scope_.FindVar(output) == nullptr) {
510+
511+
bool HasOutputs(const std::string& name) const override {
512+
if (!op_.HasOutputs(name)) {
513+
return false;
514+
}
515+
auto outputs = op_.Outputs(name);
516+
if (outputs.empty()) {
520517
return false;
521518
}
519+
for (auto& output : outputs) {
520+
if (scope_.FindVar(output) == nullptr) {
521+
return false;
522+
}
523+
}
524+
return true;
522525
}
523-
return true;
524-
}
525526

526-
void RuntimeInferShapeContext::ShareLoD(const std::string& in,
527-
const std::string& out, size_t i,
528-
size_t j) const {
529-
PADDLE_ENFORCE_LT(i, Inputs(in).size());
530-
PADDLE_ENFORCE_LT(j, Outputs(out).size());
531-
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
532-
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
533-
if (!in_var->IsType<LoDTensor>()) return;
534-
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
535-
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
536-
auto in_tensor = in_var->Get<LoDTensor>();
537-
auto* out_tensor = out_var->GetMutable<LoDTensor>();
538-
out_tensor->set_lod(in_tensor.lod());
527+
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
528+
529+
const std::vector<std::string>& Inputs(
530+
const std::string& name) const override {
531+
return op_.Inputs(name);
532+
}
533+
534+
const std::vector<std::string>& Outputs(
535+
const std::string& name) const override {
536+
return op_.Outputs(name);
537+
}
538+
539+
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
540+
size_t j = 0) const override {
541+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
542+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
543+
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
544+
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
545+
if (!in_var->IsType<LoDTensor>()) return;
546+
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
547+
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
548+
auto in_tensor = in_var->Get<LoDTensor>();
549+
auto* out_tensor = out_var->GetMutable<LoDTensor>();
550+
out_tensor->set_lod(in_tensor.lod());
539551

540552
// TODO(dzhwinter) : reuse ShareLoD in most operators.
541553
// Need to call ShareLayout explicitly in sequence related ops.
542554
// Shall we have a better method to shared info between in/out Tensor?
543555
#ifdef PADDLE_WITH_MKLDNN
544-
// Fix me: ugly workaround below
545-
// Correct solution:
546-
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
547-
// layout of output tensor should be set "manually" in Compute()
548-
// of each OPKernel. The reason layout should NOT be shared between
549-
// input and output "automatically" (now by InferShape()->ShareLoD())
550-
// is that layout transform may occur after InferShape().
551-
// Workaround:
552-
// Skip set_layout() when input layout is kMKLDNN
553-
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
554-
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
555-
// in Compute()
556-
if (in_tensor.layout() != DataLayout::kMKLDNN)
556+
// Fix me: ugly workaround below
557+
// Correct solution:
558+
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
559+
// layout of output tensor should be set "manually" in Compute()
560+
// of each OPKernel. The reason layout should NOT be shared between
561+
// input and output "automatically" (now by InferShape()->ShareLoD())
562+
// is that layout transform may occur after InferShape().
563+
// Workaround:
564+
// Skip set_layout() when input layout is kMKLDNN
565+
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
566+
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
567+
// in Compute()
568+
if (in_tensor.layout() != DataLayout::kMKLDNN)
557569
#endif
570+
out_tensor->set_layout(in_tensor.layout());
571+
}
572+
573+
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
574+
size_t j = 0) const {
575+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
576+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
577+
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
578+
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
579+
if (!in_var->IsType<LoDTensor>()) return;
580+
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
581+
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
582+
auto in_tensor = in_var->Get<LoDTensor>();
583+
auto* out_tensor = out_var->GetMutable<LoDTensor>();
558584
out_tensor->set_layout(in_tensor.layout());
559-
}
585+
}
560586

561-
void RuntimeInferShapeContext::ShareLayout(const std::string& in,
562-
const std::string& out, size_t i,
563-
size_t j) const {
564-
PADDLE_ENFORCE_LT(i, Inputs(in).size());
565-
PADDLE_ENFORCE_LT(j, Outputs(out).size());
566-
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
567-
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
568-
if (!in_var->IsType<LoDTensor>()) return;
569-
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
570-
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
571-
auto in_tensor = in_var->Get<LoDTensor>();
572-
auto* out_tensor = out_var->GetMutable<LoDTensor>();
573-
out_tensor->set_layout(in_tensor.layout());
574-
}
575-
576-
DDim RuntimeInferShapeContext::GetDim(const std::string& name) const {
577-
Variable* var = scope_.FindVar(name);
578-
PADDLE_ENFORCE_NOT_NULL(var);
579-
if (var->IsType<LoDTensor>()) {
580-
return var->Get<LoDTensor>().dims();
581-
} else if (var->IsType<SelectedRows>()) {
582-
return var->Get<SelectedRows>().GetCompleteDims();
583-
} else {
584-
PADDLE_THROW(
585-
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
586-
"type_id is %s.",
587-
name, var->Type().name());
587+
bool IsRuntime() const override { return true; }
588+
589+
protected:
590+
DDim GetDim(const std::string& name) const override {
591+
Variable* var = scope_.FindVar(name);
592+
PADDLE_ENFORCE_NOT_NULL(var);
593+
if (var->IsType<LoDTensor>()) {
594+
return var->Get<LoDTensor>().dims();
595+
} else if (var->IsType<SelectedRows>()) {
596+
return var->Get<SelectedRows>().GetCompleteDims();
597+
} else {
598+
PADDLE_THROW(
599+
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
600+
"type_id is %s.",
601+
name, var->Type().name());
602+
}
588603
}
589-
}
590604

591-
void RuntimeInferShapeContext::SetDim(const std::string& name,
592-
const DDim& dim) {
593-
Variable* var = scope_.FindVar(name);
594-
if (var->IsType<LoDTensor>()) {
595-
var->GetMutable<LoDTensor>()->Resize(dim);
596-
} else if (var->IsType<SelectedRows>()) {
597-
var->GetMutable<SelectedRows>()->set_height(dim[0]);
598-
} else {
599-
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", name,
600-
var->Type().name());
605+
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
606+
PADDLE_THROW("Only compile time support this method");
601607
}
602-
}
608+
609+
void SetDim(const std::string& name, const DDim& dim) override {
610+
Variable* var = scope_.FindVar(name);
611+
if (var->IsType<LoDTensor>()) {
612+
var->GetMutable<LoDTensor>()->Resize(dim);
613+
} else if (var->IsType<SelectedRows>()) {
614+
var->GetMutable<SelectedRows>()->set_height(dim[0]);
615+
} else {
616+
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
617+
name, var->Type().name());
618+
}
619+
}
620+
621+
void SetRepeatedDims(const std::string& name,
622+
const std::vector<DDim>& dims) override {
623+
PADDLE_THROW("Only compile time support this method");
624+
}
625+
626+
proto::VarType::Type GetVarType(const std::string& name) const override {
627+
auto* var = scope_.FindVar(name);
628+
return ToVarType(var->Type());
629+
}
630+
631+
InferShapeVarPtr GetVarPtr(const std::string& name) override {
632+
return scope_.FindVar(name);
633+
}
634+
635+
private:
636+
const OperatorBase& op_;
637+
const Scope& scope_;
638+
};
603639

604640
static void CheckTensorNANOrInf(const std::string& name,
605641
const framework::Tensor& tensor) {

paddle/fluid/framework/shape_runtime_infer.h

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)