Skip to content

Commit 07462de

Browse files
Author: Yibing Liu
Cherry-pick lod_reset fix to 1.4 (#16939)
test=release/1.4
1 parent d1c5da2 commit 07462de

File tree

4 files changed: +37 additions, -7 deletions

paddle/fluid/framework/var_type_inference.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,16 @@ class InferVarTypeContext {
4545

4646
virtual bool HasInput(const std::string& name) const {
4747
PADDLE_ENFORCE_NOT_NULL(op_);
48-
return op_->Inputs().count(name) > 0;
48+
auto& inputs = op_->Inputs();
49+
auto input = inputs.find(name);
50+
return input != inputs.end() && !input->second.empty();
4951
}
5052

5153
virtual bool HasOutput(const std::string& name) const {
5254
PADDLE_ENFORCE_NOT_NULL(op_);
53-
return op_->Outputs().count(name) > 0;
55+
auto& outputs = op_->Outputs();
56+
auto output = outputs.find(name);
57+
return output != outputs.end() && !output->second.empty();
5458
}
5559

5660
virtual const std::vector<std::string>& Input(const std::string& name) const {

paddle/fluid/operators/lod_reset_op.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
3030

3131
if (!ctx->HasInput("Y")) {
3232
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
33-
PADDLE_ENFORCE_GT(level0.size(), 1,
33+
PADDLE_ENFORCE_GT(level0.size(), 0,
3434
"If Input(Y) not provided, the target lod should be "
3535
"specified by attribute `target_lod`.");
36-
} else {
36+
} else if (ctx->IsRuntime()) {
3737
ctx->ShareLoD("Y", "Out");
3838
}
3939

@@ -48,6 +48,23 @@ class LoDResetOp : public framework::OperatorWithKernel {
4848
}
4949
};
5050

51+
class LoDResetOpVarTypeInference : public framework::VarTypeInference {
52+
public:
53+
void operator()(framework::InferVarTypeContext *ctx) const override {
54+
auto x_var_name = ctx->Input("X").front();
55+
auto out_var_name = ctx->Output("Out").front();
56+
if (ctx->HasInput("Y")) {
57+
auto y_var_name = ctx->Input("Y").front();
58+
auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
59+
ctx->SetLoDLevel(out_var_name, y_lod_level);
60+
} else {
61+
ctx->SetLoDLevel(out_var_name, 1);
62+
}
63+
ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
64+
ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
65+
}
66+
};
67+
5168
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
5269
public:
5370
void Make() override {
@@ -177,9 +194,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LoDResetGradNoNeedBufferVarInference,
177194

178195
namespace ops = paddle::operators;
179196
REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
180-
ops::LoDResetGradDescMaker);
197+
ops::LoDResetGradDescMaker, ops::LoDResetOpVarTypeInference);
181198
REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
182199
ops::LoDResetGradNoNeedBufferVarInference);
200+
183201
REGISTER_OP_CPU_KERNEL(
184202
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
185203
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,

paddle/fluid/operators/lod_reset_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
6363
"Target LoD should be a vector end with the "
6464
"first dimension of Input(X).");
6565
for (size_t i = 0; i < level0.size() - 1; ++i) {
66-
PADDLE_ENFORCE(level0[i + 1] > level0[i],
66+
PADDLE_ENFORCE(level0[i + 1] >= level0[i],
6767
"Target LoD should be an ascending vector.");
6868
}
6969

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,15 @@ def test_lod_reset(self):
13331333
x = layers.data(name='x', shape=[10], dtype='float32')
13341334
y = layers.data(
13351335
name='y', shape=[10, 20], dtype='float32', lod_level=2)
1336-
print(layers.lod_reset(x=x, y=y))
1336+
z = layers.lod_reset(x=x, y=y)
1337+
self.assertTrue(z.lod_level == 2)
1338+
# case 2
1339+
lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64')
1340+
z = layers.lod_reset(x=x, y=lod_tensor_in)
1341+
self.assertTrue(z.lod_level == 1)
1342+
# case 3
1343+
z = layers.lod_reset(x=x, target_lod=[1, 2, 3])
1344+
self.assertTrue(z.lod_level == 1)
13371345
print(str(program))
13381346

13391347
def test_label_smooth(self):

Commit comments (0)