Skip to content

Commit ee11f00

Browse files
authored
add shareLod (#5259)
* add shareLod * fix sequence_conv grad infershape
1 parent 360cb18 commit ee11f00

File tree

5 files changed

+33
-7
lines changed

5 files changed

+33
-7
lines changed

paddle/framework/op_desc.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext {
5252
const std::vector<std::string> &Outputs(
5353
const std::string &name) const override;
5454

55+
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
56+
size_t j = 0) const override {
57+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
58+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
59+
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
60+
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
61+
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
62+
VLOG(3) << "input " << in << "is not LodTensor";
63+
return;
64+
}
65+
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
66+
"The %d-th output of Output(%s) must be LoDTensor.", j,
67+
out);
68+
in_var->SetLoDLevel(out_var->GetLodLevel());
69+
}
70+
5571
private:
5672
DDim GetDim(const std::string &name) const override;
5773

paddle/framework/operator.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
351351
return op_.Outputs(name);
352352
}
353353

354+
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
355+
size_t j = 0) const override {
356+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
357+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
358+
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
359+
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
360+
if (!in_var->IsType<LoDTensor>()) return;
361+
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
362+
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
363+
auto in_tensor = in_var->Get<LoDTensor>();
364+
auto* out_tensor = out_var->GetMutable<LoDTensor>();
365+
out_tensor->set_lod(in_tensor.lod());
366+
}
367+
354368
private:
355369
DDim GetDim(const std::string& name) const override {
356370
Variable* var = scope_.FindVar(name);

paddle/framework/shape_inference.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim(
2828
SetDims(names, dims);
2929
}
3030

31-
void InferShapeContext::ShareLoD(const std::string &in, const std::string &out,
32-
size_t i, size_t j) const {}
33-
3431
std::vector<framework::DDim> InferShapeContext::GetDims(
3532
const std::vector<std::string> &names) const {
3633
std::vector<framework::DDim> ret;

paddle/framework/shape_inference.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ class InferShapeContext {
4343
virtual const std::vector<std::string> &Outputs(
4444
const std::string &name) const = 0;
4545

46-
// TODO(qiao) implement this function
47-
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
48-
size_t j = 0) const;
46+
virtual void ShareLoD(const std::string &in, const std::string &out,
47+
size_t i = 0, size_t j = 0) const = 0;
4948

5049
protected:
5150
virtual framework::DDim GetDim(const std::string &name) const = 0;

paddle/operators/sequence_conv_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
8989
}
9090
if (ctx->HasOutput(framework::GradVarName("X"))) {
9191
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
92-
ctx->ShareLoD(framework::GradVarName("X"), "X");
92+
ctx->ShareLoD("X", framework::GradVarName("X"));
9393
}
9494
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
9595
ctx->SetOutputDim(framework::GradVarName("Filter"),

0 commit comments

Comments
 (0)