Skip to content

Commit 8284947

Browse files
wanghaoshuangqingqing01
authored andcommitted
Fix infershape of im2sequence. (#12183)
1 parent d113590 commit 8284947

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

paddle/fluid/operators/im2sequence_op.cc

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
3333

3434
PADDLE_ENFORCE_EQ(in_dim.size(), 4,
3535
"Input(X) format must be 4D tensor, eg., NCHW.");
36-
int batch_size = in_dim[0];
3736
int img_channels = in_dim[1];
38-
int img_height = in_dim[2];
39-
int img_width = in_dim[3];
4037

4138
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
4239
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
4340
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
4441

45-
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
46-
paddings[2], strides[0]);
47-
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
48-
paddings[3], strides[1]);
49-
50-
ctx->SetOutputDim("Out", {batch_size * output_height * output_width,
51-
img_channels * kernels[0] * kernels[1]});
42+
ctx->SetOutputDim("Out",
43+
{in_dim[0], img_channels * kernels[0] * kernels[1]});
5244
}
5345
};
5446

paddle/fluid/operators/im2sequence_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,13 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
109109
}
110110
out->set_lod(lod);
111111
} else {
112-
out->mutable_data<T>(ctx.GetPlace());
113112
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
114113
paddings[2], strides[0]);
115114
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
116115
paddings[3], strides[1]);
117-
116+
out->mutable_data<T>({batch_size * output_height * output_width,
117+
img_channels * kernels[0] * kernels[1]},
118+
ctx.GetPlace());
118119
const std::vector<int> dilations({1, 1});
119120
auto out_dims = out->dims();
120121
out->Resize({batch_size, out->numel() / batch_size});

0 commit comments

Comments
 (0)