Skip to content

Commit a8b3996

Browse files
authored
Merge pull request #7219 from reyoung/feature/correctly_handle_lod_information_for_image_operators
Correctly handle lod information of image operators
2 parents f3c42f6 + 040dc59 commit a8b3996

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

paddle/operators/batch_norm_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
6464
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
6565
"Input X must have 2 to 5 dimensions.");
6666

67-
const int C =
67+
const int64_t C =
6868
(data_layout == DataLayout::kNCHW ? x_dims[1]
6969
: x_dims[x_dims.size() - 1]);
7070

@@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
7878
ctx->SetOutputDim("VarianceOut", {C});
7979
ctx->SetOutputDim("SavedMean", {C});
8080
ctx->SetOutputDim("SavedVariance", {C});
81+
ctx->ShareLoD("X", "Y");
8182
}
8283
};
8384

paddle/operators/conv_op.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
4444
paddings.size(), strides.size(),
4545
"Conv paddings dimension and Conv strides dimension should be the same.");
4646

47-
int input_channels = in_dims[1];
48-
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
47+
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
4948
"The number of input channels should be equal to filter "
5049
"channels * groups.");
5150

52-
int output_channels = filter_dims[0];
5351
PADDLE_ENFORCE_EQ(
54-
output_channels % groups, 0,
52+
filter_dims[0] % groups, 0,
5553
"The number of output channels should be divided by groups.");
5654

5755
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
@@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
6664
dilations[i], paddings[i], strides[i]));
6765
}
6866
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
67+
ctx->ShareLoD("Input", "Output");
6968
}
7069

7170
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)

paddle/operators/pool_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
5858
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
5959
}
6060
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
61+
ctx->ShareLoD("X", "Out");
6162
}
6263

6364
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {

0 commit comments

Comments
 (0)