Skip to content

Commit e0436ad

Browse files
committed
refine fusion lstm infershape
1 parent 94b66bd commit e0436ad

File tree

3 files changed

+260
-184
lines changed

3 files changed

+260
-184
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 119 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/lod_tensor.h"
2222
#include "paddle/fluid/framework/operator.h"
2323
#include "paddle/fluid/framework/shape_inference.h"
24+
#include "paddle/fluid/framework/shape_runtime_infer.h"
2425
#include "paddle/fluid/framework/var_type.h"
2526
#include "paddle/fluid/platform/profiler.h"
2627

@@ -458,187 +459,147 @@ bool OpSupportGPU(const std::string& op_type) {
458459
return false;
459460
}
460461

461-
class RuntimeInferShapeContext : public InferShapeContext {
462-
public:
463-
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
464-
: op_(op), scope_(scope) {}
465-
466-
bool HasInput(const std::string& name) const override {
467-
if (!op_.HasInputs(name)) {
468-
return false;
469-
}
470-
auto& ins = Inputs(name);
471-
size_t length = ins.size();
472-
if (length == 0) {
473-
return false;
474-
}
475-
PADDLE_ENFORCE_EQ(length, 1UL,
476-
"Input %s should not have more than one inputs", name);
477-
auto ipt = ins[0];
478-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
479-
return var != nullptr;
462+
bool RuntimeInferShapeContext::HasInput(const std::string& name) const {
463+
if (!op_.HasInputs(name)) {
464+
return false;
480465
}
481-
482-
bool HasOutput(const std::string& name) const override {
483-
if (!op_.HasOutputs(name)) {
484-
return false;
485-
}
486-
auto& outs = Outputs(name);
487-
size_t length = outs.size();
488-
if (length == 0) {
489-
return false;
490-
}
491-
PADDLE_ENFORCE_EQ(length, 1UL,
492-
"Output %s should not have more than one inputs", name);
493-
auto ipt = outs[0];
494-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
495-
return var != nullptr;
466+
auto& ins = Inputs(name);
467+
size_t length = ins.size();
468+
if (length == 0) {
469+
return false;
496470
}
471+
PADDLE_ENFORCE_EQ(length, 1UL,
472+
"Input %s should not have more than one inputs", name);
473+
auto ipt = ins[0];
474+
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
475+
return var != nullptr;
476+
}
497477

498-
bool HasInputs(const std::string& name) const override {
499-
if (!op_.HasInputs(name)) {
500-
return false;
501-
}
502-
auto inputs = op_.Inputs(name);
503-
if (inputs.empty()) {
504-
return false;
505-
}
506-
for (auto& input : inputs) {
507-
if (scope_.FindVar(input) == nullptr) {
508-
return false;
509-
}
510-
}
511-
return true;
478+
bool RuntimeInferShapeContext::HasOutput(const std::string& name) const {
479+
if (!op_.HasOutputs(name)) {
480+
return false;
512481
}
482+
auto& outs = Outputs(name);
483+
size_t length = outs.size();
484+
if (length == 0) {
485+
return false;
486+
}
487+
PADDLE_ENFORCE_EQ(length, 1UL,
488+
"Output %s should not have more than one inputs", name);
489+
auto ipt = outs[0];
490+
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
491+
return var != nullptr;
492+
}
513493

514-
bool HasOutputs(const std::string& name) const override {
515-
if (!op_.HasOutputs(name)) {
516-
return false;
517-
}
518-
auto outputs = op_.Outputs(name);
519-
if (outputs.empty()) {
494+
bool RuntimeInferShapeContext::HasInputs(const std::string& name) const {
495+
if (!op_.HasInputs(name)) {
496+
return false;
497+
}
498+
auto inputs = op_.Inputs(name);
499+
if (inputs.empty()) {
500+
return false;
501+
}
502+
for (auto& input : inputs) {
503+
if (scope_.FindVar(input) == nullptr) {
520504
return false;
521505
}
522-
for (auto& output : outputs) {
523-
if (scope_.FindVar(output) == nullptr) {
524-
return false;
525-
}
526-
}
527-
return true;
528506
}
507+
return true;
508+
}
529509

530-
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
531-
532-
const std::vector<std::string>& Inputs(
533-
const std::string& name) const override {
534-
return op_.Inputs(name);
510+
bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const {
511+
if (!op_.HasOutputs(name)) {
512+
return false;
535513
}
536-
537-
const std::vector<std::string>& Outputs(
538-
const std::string& name) const override {
539-
return op_.Outputs(name);
514+
auto outputs = op_.Outputs(name);
515+
if (outputs.empty()) {
516+
return false;
540517
}
518+
for (auto& output : outputs) {
519+
if (scope_.FindVar(output) == nullptr) {
520+
return false;
521+
}
522+
}
523+
return true;
524+
}
541525

542-
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
543-
size_t j = 0) const override {
544-
PADDLE_ENFORCE_LT(i, Inputs(in).size());
545-
PADDLE_ENFORCE_LT(j, Outputs(out).size());
546-
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
547-
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
548-
if (!in_var->IsType<LoDTensor>()) return;
549-
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
550-
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
551-
auto in_tensor = in_var->Get<LoDTensor>();
552-
auto* out_tensor = out_var->GetMutable<LoDTensor>();
553-
out_tensor->set_lod(in_tensor.lod());
526+
void RuntimeInferShapeContext::ShareLoD(const std::string& in,
527+
const std::string& out, size_t i,
528+
size_t j) const {
529+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
530+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
531+
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
532+
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
533+
if (!in_var->IsType<LoDTensor>()) return;
534+
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
535+
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
536+
auto in_tensor = in_var->Get<LoDTensor>();
537+
auto* out_tensor = out_var->GetMutable<LoDTensor>();
538+
out_tensor->set_lod(in_tensor.lod());
554539

555540
// TODO(dzhwinter) : reuse ShareLoD in most operators.
556541
// Need to call ShareLayout explicitly in sequence related ops.
557542
// Shall we have a better method to shared info between in/out Tensor?
558543
#ifdef PADDLE_WITH_MKLDNN
559-
// Fix me: ugly workaround below
560-
// Correct solution:
561-
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
562-
// layout of output tensor should be set "manually" in Compute()
563-
// of each OPKernel. The reason layout should NOT be shared between
564-
// input and output "automatically" (now by InferShape()->ShareLoD())
565-
// is that layout transform may occur after InferShape().
566-
// Workaround:
567-
// Skip set_layout() when input layout is kMKLDNN
568-
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
569-
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
570-
// in Compute()
571-
if (in_tensor.layout() != DataLayout::kMKLDNN)
544+
// Fix me: ugly workaround below
545+
// Correct solution:
546+
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
547+
// layout of output tensor should be set "manually" in Compute()
548+
// of each OPKernel. The reason layout should NOT be shared between
549+
// input and output "automatically" (now by InferShape()->ShareLoD())
550+
// is that layout transform may occur after InferShape().
551+
// Workaround:
552+
// Skip set_layout() when input layout is kMKLDNN
553+
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
554+
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
555+
// in Compute()
556+
if (in_tensor.layout() != DataLayout::kMKLDNN)
572557
#endif
573-
out_tensor->set_layout(in_tensor.layout());
574-
}
575-
576-
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
577-
size_t j = 0) const {
578-
PADDLE_ENFORCE_LT(i, Inputs(in).size());
579-
PADDLE_ENFORCE_LT(j, Outputs(out).size());
580-
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
581-
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
582-
if (!in_var->IsType<LoDTensor>()) return;
583-
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
584-
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
585-
auto in_tensor = in_var->Get<LoDTensor>();
586-
auto* out_tensor = out_var->GetMutable<LoDTensor>();
587558
out_tensor->set_layout(in_tensor.layout());
588-
}
589-
590-
bool IsRuntime() const override { return true; }
591-
592-
protected:
593-
DDim GetDim(const std::string& name) const override {
594-
Variable* var = scope_.FindVar(name);
595-
PADDLE_ENFORCE_NOT_NULL(var);
596-
if (var->IsType<LoDTensor>()) {
597-
return var->Get<LoDTensor>().dims();
598-
} else if (var->IsType<SelectedRows>()) {
599-
return var->Get<SelectedRows>().GetCompleteDims();
600-
} else {
601-
PADDLE_THROW(
602-
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
603-
"type_id is %s.",
604-
name, var->Type().name());
605-
}
606-
}
607-
608-
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
609-
PADDLE_THROW("Only compile time support this method");
610-
}
611-
612-
void SetDim(const std::string& name, const DDim& dim) override {
613-
Variable* var = scope_.FindVar(name);
614-
if (var->IsType<LoDTensor>()) {
615-
var->GetMutable<LoDTensor>()->Resize(dim);
616-
} else if (var->IsType<SelectedRows>()) {
617-
var->GetMutable<SelectedRows>()->set_height(dim[0]);
618-
} else {
619-
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
620-
name, var->Type().name());
621-
}
622-
}
623-
624-
void SetRepeatedDims(const std::string& name,
625-
const std::vector<DDim>& dims) override {
626-
PADDLE_THROW("Only compile time support this method");
627-
}
559+
}
628560

629-
proto::VarType::Type GetVarType(const std::string& name) const override {
630-
auto* var = scope_.FindVar(name);
631-
return ToVarType(var->Type());
561+
void RuntimeInferShapeContext::ShareLayout(const std::string& in,
562+
const std::string& out, size_t i,
563+
size_t j) const {
564+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
565+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
566+
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
567+
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
568+
if (!in_var->IsType<LoDTensor>()) return;
569+
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
570+
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
571+
auto in_tensor = in_var->Get<LoDTensor>();
572+
auto* out_tensor = out_var->GetMutable<LoDTensor>();
573+
out_tensor->set_layout(in_tensor.layout());
574+
}
575+
576+
DDim RuntimeInferShapeContext::GetDim(const std::string& name) const {
577+
Variable* var = scope_.FindVar(name);
578+
PADDLE_ENFORCE_NOT_NULL(var);
579+
if (var->IsType<LoDTensor>()) {
580+
return var->Get<LoDTensor>().dims();
581+
} else if (var->IsType<SelectedRows>()) {
582+
return var->Get<SelectedRows>().GetCompleteDims();
583+
} else {
584+
PADDLE_THROW(
585+
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
586+
"type_id is %s.",
587+
name, var->Type().name());
632588
}
589+
}
633590

634-
InferShapeVarPtr GetVarPtr(const std::string& name) override {
635-
return scope_.FindVar(name);
591+
void RuntimeInferShapeContext::SetDim(const std::string& name,
592+
const DDim& dim) {
593+
Variable* var = scope_.FindVar(name);
594+
if (var->IsType<LoDTensor>()) {
595+
var->GetMutable<LoDTensor>()->Resize(dim);
596+
} else if (var->IsType<SelectedRows>()) {
597+
var->GetMutable<SelectedRows>()->set_height(dim[0]);
598+
} else {
599+
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", name,
600+
var->Type().name());
636601
}
637-
638-
private:
639-
const OperatorBase& op_;
640-
const Scope& scope_;
641-
};
602+
}
642603

643604
static void CheckTensorNANOrInf(const std::string& name,
644605
const framework::Tensor& tensor) {

0 commit comments

Comments
 (0)