Skip to content

Commit 85c203b

Browse files
Make bilinear_interp_op support attrs from input. (#11041)
* Make bilinear_interp_op support attrs from input. * Fix python api.
1 parent bfecb57 commit 85c203b

File tree

5 files changed

+111
-15
lines changed

5 files changed

+111
-15
lines changed

paddle/fluid/operators/bilinear_interp_op.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
3434
int out_w = ctx->Attrs().Get<int>("out_w");
3535
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
3636

37+
if (ctx->HasInput("OutSize")) {
38+
auto out_size_dim = ctx->GetInputDim("OutSize");
39+
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
40+
"OutSize's dimension size must be 1");
41+
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
42+
}
3743
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
3844
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
3945
}
46+
47+
protected:
48+
framework::OpKernelType GetExpectedKernelType(
49+
const framework::ExecutionContext& ctx) const override {
50+
return framework::OpKernelType(
51+
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
52+
}
4053
};
4154

4255
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
4558
AddInput("X",
4659
"(Tensor) The input tensor of bilinear interpolation, "
4760
"This is a 4-D tensor with shape of (N x C x h x w)");
61+
AddInput("OutSize",
62+
"(Tensor) This is a 1-D tensor with two number. "
63+
"The first number is height and the second number is width.")
64+
.AsDispensable();
4865
AddOutput("Out",
4966
"(Tensor) The dimension of output is (N x C x out_h x out_w]");
5067

@@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel {
7895
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
7996
}
8097
}
98+
99+
framework::OpKernelType GetExpectedKernelType(
100+
const framework::ExecutionContext& ctx) const override {
101+
return framework::OpKernelType(
102+
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
103+
}
81104
};
82105

83106
} // namespace operators

paddle/fluid/operators/bilinear_interp_op.cu

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
102102
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
103103
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
104104
auto* input = input_t->data<T>();
105-
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
106105

107106
int out_h = ctx.Attr<int>("out_h");
108107
int out_w = ctx.Attr<int>("out_w");
108+
auto out_dims = output_t->dims();
109+
auto out_size_t = ctx.Input<Tensor>("OutSize");
110+
if (out_size_t != nullptr) {
111+
Tensor sizes;
112+
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
113+
auto size_data = sizes.data<int>();
114+
out_h = size_data[0];
115+
out_w = size_data[1];
116+
}
117+
auto* output = output_t->mutable_data<T>(
118+
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
119+
109120
int batch_size = input_t->dims()[0];
110121
int channels = input_t->dims()[1];
111122
int in_h = input_t->dims()[2];
@@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
139150
void Compute(const framework::ExecutionContext& ctx) const override {
140151
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
141152
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
142-
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
143153
auto* d_output = d_output_t->data<T>();
154+
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
144155

145156
auto& device_ctx =
146157
ctx.template device_context<platform::CUDADeviceContext>();
@@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
149160

150161
int out_h = ctx.Attr<int>("out_h");
151162
int out_w = ctx.Attr<int>("out_w");
163+
164+
auto out_size_t = ctx.Input<Tensor>("OutSize");
165+
if (out_size_t != nullptr) {
166+
Tensor sizes;
167+
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
168+
auto size_data = sizes.data<int>();
169+
out_h = size_data[0];
170+
out_w = size_data[1];
171+
}
172+
152173
int batch_size = d_input_t->dims()[0];
153174
int channels = d_input_t->dims()[1];
154175
int in_h = d_input_t->dims()[2];

paddle/fluid/operators/bilinear_interp_op.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
2424
void Compute(const framework::ExecutionContext& ctx) const override {
2525
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
2626
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
27+
auto out_dims = output_t->dims();
2728
auto* input = input_t->data<T>();
28-
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
29-
3029
int out_h = ctx.Attr<int>("out_h");
3130
int out_w = ctx.Attr<int>("out_w");
31+
auto out_size_t = ctx.Input<Tensor>("OutSize");
32+
if (out_size_t != nullptr) {
33+
auto out_size_data = out_size_t->data<int>();
34+
out_h = out_size_data[0];
35+
out_w = out_size_data[1];
36+
}
37+
auto* output = output_t->mutable_data<T>(
38+
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
3239
int batch_size = input_t->dims()[0];
3340
int channels = input_t->dims()[1];
3441
int in_h = input_t->dims()[2];
@@ -83,16 +90,23 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
8390
void Compute(const framework::ExecutionContext& ctx) const override {
8491
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
8592
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
86-
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
8793
auto* d_output = d_output_t->data<T>();
88-
94+
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
8995
auto& device_ctx =
9096
ctx.template device_context<platform::CPUDeviceContext>();
9197
math::SetConstant<platform::CPUDeviceContext, T> zero;
9298
zero(device_ctx, d_input_t, static_cast<T>(0.0));
9399

94100
int out_h = ctx.Attr<int>("out_h");
95101
int out_w = ctx.Attr<int>("out_w");
102+
103+
auto out_size_t = ctx.Input<Tensor>("OutSize");
104+
if (out_size_t != nullptr) {
105+
auto out_size_data = out_size_t->data<int>();
106+
out_h = out_size_data[0];
107+
out_w = out_size_data[1];
108+
}
109+
96110
int batch_size = d_input_t->dims()[0];
97111
int channels = d_input_t->dims()[1];
98112
int in_h = d_input_t->dims()[2];

python/paddle/fluid/layers/nn.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
39443944
input (Variable): The input tensor of bilinear interpolation,
39453945
This is a 4-D tensor of the shape
39463946
(num_batches, channels, in_h, in_w).
3947-
out_shape(list|tuple|None): Output shape of bilinear interpolation
3947+
out_shape(list|tuple|Variable|None): Output shape of bilinear interpolation
39483948
layer, the shape is (out_h, out_w).
39493949
Default: None
39503950
scale(int|None): The multiplier for the input height or width.
@@ -3971,21 +3971,28 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
39713971
def _is_list_or_turple_(data):
39723972
return (isinstance(data, list) or isinstance(data, tuple))
39733973

3974+
out_h = 0
3975+
out_w = 0
3976+
inputs = {"X": input}
39743977
if out_shape is not None:
3975-
if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2):
3978+
if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2) and (
3979+
out_shape is not Variable):
39763980
raise ValueError('out_shape should be a list or tuple ',
39773981
'with length 2, (out_h, out_w).')
3978-
out_shape = list(map(int, out_shape))
3979-
out_h = out_shape[0]
3980-
out_w = out_shape[1]
3982+
if _is_list_or_turple_(out_shape):
3983+
out_shape = list(map(int, out_shape))
3984+
out_h = out_shape[0]
3985+
out_w = out_shape[1]
3986+
else:
3987+
inputs['OutSize'] = out_shape
39813988
else:
39823989
out_h = int(input.shape[2] * scale)
39833990
out_w = int(input.shape[3] * scale)
39843991

39853992
out = helper.create_tmp_variable(dtype)
39863993
helper.append_op(
39873994
type="bilinear_interp",
3988-
inputs={"X": input},
3995+
inputs=inputs,
39893996
outputs={"Out": out},
39903997
attrs={"out_h": out_h,
39913998
"out_w": out_w})

python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from op_test import OpTest
1818

1919

20-
def bilinear_interp_np(input, out_h, out_w):
20+
def bilinear_interp_np(input, out_h, out_w, out_size):
21+
if out_size is not None:
22+
out_h = out_size[0]
23+
out_w = out_size[1]
2124
batch_size, channel, in_h, in_w = input.shape
2225
if out_h > 1:
2326
ratio_h = (in_h - 1.0) / (out_h - 1.0)
@@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w):
4952

5053
class TestBilinearInterpOp(OpTest):
5154
def setUp(self):
55+
self.out_size = None
5256
self.init_test_case()
5357
self.op_type = "bilinear_interp"
5458
input_np = np.random.random(self.input_shape).astype("float32")
55-
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w)
56-
59+
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
60+
self.out_size)
5761
self.inputs = {'X': input_np}
62+
if self.out_size is not None:
63+
self.inputs['OutSize'] = self.out_size
5864
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
5965
self.outputs = {'Out': output_np}
6066

@@ -68,6 +74,7 @@ def init_test_case(self):
6874
self.input_shape = [2, 3, 4, 4]
6975
self.out_h = 2
7076
self.out_w = 2
77+
self.out_size = np.array([3, 3]).astype("int32")
7178

7279

7380
class TestCase1(TestBilinearInterpOp):
@@ -91,5 +98,29 @@ def init_test_case(self):
9198
self.out_w = 128
9299

93100

101+
class TestCase4(TestBilinearInterpOp):
102+
def init_test_case(self):
103+
self.input_shape = [4, 1, 7, 8]
104+
self.out_h = 1
105+
self.out_w = 1
106+
self.out_size = np.array([2, 2]).astype("int32")
107+
108+
109+
class TestCase5(TestBilinearInterpOp):
110+
def init_test_case(self):
111+
self.input_shape = [3, 3, 9, 6]
112+
self.out_h = 12
113+
self.out_w = 12
114+
self.out_size = np.array([11, 11]).astype("int32")
115+
116+
117+
class TestCase6(TestBilinearInterpOp):
118+
def init_test_case(self):
119+
self.input_shape = [1, 1, 128, 64]
120+
self.out_h = 64
121+
self.out_w = 128
122+
self.out_size = np.array([65, 129]).astype("int32")
123+
124+
94125
if __name__ == "__main__":
95126
unittest.main()

0 commit comments

Comments
 (0)