Skip to content

Commit 28dc9ba

Browse files
Add shape op to get the shape of variable. (#11048)
* Add shape op to get the shape of variable. * Rename get_shape to shape. * Add checker for output and fix comments.
1 parent ed36591 commit 28dc9ba

File tree

5 files changed

+160
-0
lines changed

5 files changed

+160
-0
lines changed

paddle/fluid/operators/shape_op.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/shape_op.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class ShapeOp : public framework::OperatorWithKernel {
22+
public:
23+
using framework::OperatorWithKernel::OperatorWithKernel;
24+
25+
void InferShape(framework::InferShapeContext *ctx) const override {
26+
PADDLE_ENFORCE(ctx->HasInput("Input"),
27+
"Input (Input) of get_shape op should not be null.");
28+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
29+
"Output (Out) of get_shape op should not be null.");
30+
auto in_dim = ctx->GetInputDim("Input");
31+
ctx->SetOutputDim("Out", {in_dim.size()});
32+
}
33+
};
34+
35+
class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {
36+
public:
37+
void Make() override {
38+
AddInput("Input", "(Tensor), The input tensor.");
39+
AddOutput("Out", "(Tensor), The shape of input tensor.");
40+
AddComment(R"DOC(
41+
Shape Operator.
42+
Get the shape of input tensor.
43+
)DOC");
44+
}
45+
};
46+
47+
} // namespace operators
48+
} // namespace paddle
49+
50+
namespace ops = paddle::operators;
51+
REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker,
52+
paddle::framework::EmptyGradOpMaker);
53+
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int64_t>,
54+
ops::ShapeKernel<float>, ops::ShapeKernel<double>);

paddle/fluid/operators/shape_op.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/shape_op.h"
16+
17+
REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel<int>,
18+
paddle::operators::ShapeKernel<int64_t>,
19+
paddle::operators::ShapeKernel<float>,
20+
paddle::operators::ShapeKernel<double>);

paddle/fluid/operators/shape_op.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <algorithm>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
template <typename T>
25+
class ShapeKernel : public framework::OpKernel<T> {
26+
public:
27+
void Compute(const framework::ExecutionContext& ctx) const override {
28+
auto* in_t = ctx.Input<Tensor>("Input");
29+
auto* out_t = ctx.Output<Tensor>("Out");
30+
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
31+
auto in_dims = in_t->dims();
32+
for (int i = 0; i < in_dims.size(); ++i) {
33+
out_data[i] = in_dims[i];
34+
}
35+
}
36+
};
37+
} // namespace operators
38+
} // namespace paddle

python/paddle/fluid/layers/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
'cumsum',
7272
'scatter',
7373
'sum',
74+
'shape',
7475
] + __activations__
7576

7677
for _OP in set(__all__):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
from op_test import OpTest
18+
19+
20+
class TestShapeOp(OpTest):
21+
def setUp(self):
22+
self.op_type = "shape"
23+
self.config()
24+
self.shape = [2, 3]
25+
input = np.zeros(self.shape)
26+
self.inputs = {'Input': input}
27+
self.outputs = {'Out': np.array(self.shape)}
28+
29+
def config(self):
30+
self.shape = [2, 3]
31+
32+
def test_check_output(self):
33+
self.check_output()
34+
35+
36+
class case1(TestShapeOp):
37+
def config(self):
38+
self.shape = [2]
39+
40+
41+
class case2(TestShapeOp):
42+
def config(self):
43+
self.shape = [1, 2, 3]
44+
45+
46+
if __name__ == '__main__':
47+
unittest.main()

0 commit comments

Comments
 (0)