Skip to content

Commit ff28b1f

Browse files
authored
Merge pull request #14071 from barrierye/add_similarity_focus_op
Add similarity focus op
2 parents 688ed60 + ef8218b commit ff28b1f

File tree

5 files changed

+586
-0
lines changed

5 files changed

+586
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], vara
179179
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
180180
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
181181
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
182+
paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,))
182183
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
183184
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
184185
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/similarity_focus_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
// Declares the proto of the similarity_focus operator: one input "X", one
// output "Out", the attributes "axis" and "indexes", and the operator's
// user-facing documentation.
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 4-D tensor with shape"
             " [BatchSize, X, Y, Z]");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the similarity focus mask"
              " with the same shape of input X.");
    AddAttr<int>("axis",
                 "(int32), indicating the dimension to be selected. It can"
                 " only be 1, 2, or 3.");
    AddAttr<std::vector<int>>("indexes",
                              "(std::vector<int32>), indicating the indexes"
                              " of the selected dimension.");
    AddComment(R"DOC(
SimilarityFocus Operator.

Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
   to the axis according to the indexes. For example, if axis=1 and indexes=[a],
   it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
   is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
   row and same column has at most one number(what it means is that if the
   largest number has been found in the i-th row and the j-th column, then
   the numbers in the i-th row or j-th column will be skipped. And then the
   next largest number will be selected from the remaining numbers. Obviously
   there will be min(B, C) numbers), and mark the corresponding position of the
   3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
   each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.

Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
  }
};
56+
57+
class SimilarityFocusOp : public framework::OperatorWithKernel {
58+
public:
59+
using framework::OperatorWithKernel::OperatorWithKernel;
60+
61+
void InferShape(framework::InferShapeContext* ctx) const override {
62+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
63+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
64+
auto x_dims = ctx->GetInputDim("X");
65+
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
66+
ctx->SetOutputDim("Out", x_dims);
67+
ctx->ShareLoD("X", /*->*/ "Out");
68+
}
69+
70+
protected:
71+
framework::OpKernelType GetExpectedKernelType(
72+
const framework::ExecutionContext& ctx) const override {
73+
return framework::OpKernelType(
74+
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
75+
platform::CPUPlace());
76+
}
77+
};
78+
79+
} // namespace operators
80+
} // namespace paddle
81+
82+
// Short alias for the operators namespace, used by the registrations below.
namespace ops = paddle::operators;
83+
// Registers the similarity_focus operator with its op class, its proto maker
// and EmptyGradOpMaker (presumably meaning no gradient op is generated --
// confirm against the framework's definition).
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
84+
ops::SimilarityFocusOpMaker,
85+
paddle::framework::EmptyGradOpMaker);
86+
// Registers the CPU kernel for the float and double element types.
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
87+
ops::SimilarityFocusKernel<double>);
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
#include <cstring>
19+
#include <utility>
20+
#include <vector>
21+
#include "paddle/fluid/framework/eigen.h"
22+
#include "paddle/fluid/framework/op_registry.h"
23+
24+
namespace paddle {
25+
namespace operators {
26+
using Tensor = framework::Tensor;
27+
28+
template <typename T>
29+
// CPU kernel of the similarity_focus operator.
//
// For every batch entry and every requested index along `axis`, the kernel
// takes the 2-D plane X[n, ..., index, ...] spanned by the two remaining
// dimensions, visits its elements from largest to smallest and greedily keeps
// an element only if neither its row nor its column has been taken yet, until
// min(rows, cols) elements are kept. Each kept (row, col) position is written
// as 1 in Out for every coordinate along `axis`; all other positions are 0.
// Marks of different indexes accumulate in the same output (element-wise or).
class SimilarityFocusKernel : public framework::OpKernel<T> {
 public:
  // Reads Input(X) and the attributes "axis" / "indexes", writes Output(Out).
  // Throws (PADDLE_THROW) when axis is not 1, 2 or 3, when indexes is empty,
  // or when an index lies outside [0, X.dims()[axis]).
  void Compute(const framework::ExecutionContext& context) const override {
    Tensor* out = context.Output<Tensor>("Out");
    const Tensor* x = context.Input<Tensor>("X");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();

    const int axis = context.Attr<int>("axis");
    const std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");

    // Validate the axis first: it is used below to subscript dim[] and
    // stride[], so an unchecked value would read outside those arrays.
    if (axis < 1 || axis > 3) {
      PADDLE_THROW("Axis must be 1 or 2 or 3");
    }
    if (indexes.size() < 1) {
      PADDLE_THROW("Indexes' size can not be 0.");
    }

    const int64_t batch_size = x->dims()[0];
    int64_t dim[4];
    dim[0] = batch_size;
    for (int i = 1; i <= 3; ++i) {
      dim[i] = x->dims()[i];
    }

    // A valid index satisfies 0 <= index < dim[axis]; index == dim[axis]
    // already addresses memory past the selected dimension.
    for (auto index : indexes) {
      if (index < 0 || index >= dim[axis]) {
        PADDLE_THROW("Index exceeds tensor shape limit.");
      }
    }

    // Element strides of the contiguous 4-D layout
    // (offset = d0 * stride[0] + d1 * stride[1] + d2 * stride[2] + d3).
    int64_t stride[4];
    stride[3] = 1;
    stride[2] = dim[3];
    stride[1] = dim[2] * dim[3];
    stride[0] = dim[1] * dim[2] * dim[3];

    // The two dimensions other than `axis`, in increasing order, span the
    // plane that is searched: (2, 3) for axis 1, (1, 3) for axis 2 and
    // (1, 2) for axis 3.
    const int row_axis = (axis == 1) ? 2 : 1;
    const int col_axis = (axis == 3) ? 2 : 3;
    const int64_t rows = dim[row_axis];
    const int64_t cols = dim[col_axis];
    const int64_t max_marks = std::min(rows, cols);

    // (value, flattened plane position) pairs, sorted by descending value.
    std::vector<std::pair<T, int64_t>> array(rows * cols);
    auto cmp = [](const std::pair<T, int64_t>& lhs,
                  const std::pair<T, int64_t>& rhs) {
      return lhs.first > rhs.first;
    };

    memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
    for (int64_t n = 0; n < batch_size; ++n) {
      const int64_t batch_offset = n * stride[0];
      for (auto index : indexes) {
        // Gather the plane selected by `index` along `axis`.
        const int64_t plane_offset = batch_offset + index * stride[axis];
        for (int64_t r = 0; r < rows; ++r) {
          for (int64_t c = 0; c < cols; ++c) {
            const int64_t pos = r * cols + c;
            array[pos] = std::make_pair(
                x_data[plane_offset + r * stride[row_axis] +
                       c * stride[col_axis]],
                pos);
          }
        }

        std::sort(array.begin(), array.end(), cmp);

        // Greedy selection: at most one kept element per row and per column.
        std::vector<bool> row_used(rows), col_used(cols);
        int64_t marked = 0;
        for (const auto& item : array) {
          const int64_t r = item.second / cols;
          const int64_t c = item.second % cols;
          if (row_used[r] || col_used[c]) {
            continue;
          }
          row_used[r] = true;
          col_used[c] = true;

          // Broadcast the mark along the selected axis.
          const int64_t cell_offset =
              batch_offset + r * stride[row_axis] + c * stride[col_axis];
          for (int64_t k = 0; k < dim[axis]; ++k) {
            out_data[cell_offset + k * stride[axis]] = 1;
          }

          if (++marked == max_marks) {
            break;
          }
        }
      }
    }
  }
};
166+
167+
} // namespace operators
168+
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@
160160
'affine_grid',
161161
'sequence_reverse',
162162
'affine_channel',
163+
'similarity_focus',
163164
'hash',
164165
'grid_sampler',
165166
'log_loss',
@@ -7933,6 +7934,118 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
79337934
return out
79347935

79357936

7937+
def similarity_focus(input, axis, indexes, name=None):
    """
    SimilarityFocus Operator

    Generate a similarity focus mask with the same shape of input using the following method:
    1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
       to the axis according to the indexes. For example, if axis=1 and indexes=[a],
       it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
       is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
    2. For each index, find the largest numbers in the tensor T, so that the same
       row and same column has at most one number(what it means is that if the
       largest number has been found in the i-th row and the j-th column, then
       the numbers in the i-th row or j-th column will be skipped. And then the
       next largest number will be selected from the remaining numbers. Obviously
       there will be min(B, C) numbers), and mark the corresponding position of the
       3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
       each index.
    3. Broadcast the 3-D similarity focus mask to the same shape of input X.

    Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_

    .. code-block:: text

        * Example :

            Given a 4-D tensor x with the shape (BatchSize, C, A, B), where C is
            the number of channels and the shape of feature map is (A, B):
                x.shape = (2, 3, 2, 2)
                x.data = [[[[0.8, 0.1],
                            [0.4, 0.5]],

                           [[0.9, 0.7],
                            [0.9, 0.9]],

                           [[0.8, 0.9],
                            [0.1, 0.2]]],


                          [[[0.2, 0.5],
                            [0.3, 0.4]],

                           [[0.9, 0.7],
                            [0.8, 0.4]],

                           [[0.0, 0.2],
                            [0.4, 0.7]]]]

            Given axis: 1 (the axis of the channel)
            Given indexes: [0]

            then we get a 4-D tensor out with the same shape of input x:
                out.shape = (2, 3, 2, 2)
                out.data = [[[[1.0, 0.0],
                              [0.0, 1.0]],

                             [[1.0, 0.0],
                              [0.0, 1.0]],

                             [[1.0, 0.0],
                              [0.0, 1.0]]],

                            [[[0.0, 1.0],
                              [1.0, 0.0]],

                             [[0.0, 1.0],
                              [1.0, 0.0]],

                             [[0.0, 1.0],
                              [1.0, 0.0]]]]

    Args:
        input(Variable): The input tensor variable(default float). It should
            be a 4-D tensor with shape [BatchSize, A, B, C].
        axis(int): Indicating the dimension to be selected. It can only be
            1, 2 or 3.
        indexes(list): Indicating the indexes of the selected dimension.
        name(str|None): Name of the output variable. When None, a variable
            is created for type inference instead.

    Returns:
        Variable: A tensor variable with the same shape and same type
            as the input.

    Raises:
        TypeError: If axis is not an int or indexes is not a list.
        ValueError: If axis is not 1, 2 or 3, or if indexes is empty.

    Examples:
        .. code-block:: python

            data = fluid.layers.data(
                name='data', shape=[-1, 3, 2, 2], dtype='float32')
            x = fluid.layers.similarity_focus(input=data, axis=1, indexes=[0])
    """
    # Must stay the first statement: locals() is forwarded as keyword
    # arguments, so it may only contain the function's own parameters.
    helper = LayerHelper('similarity_focus', **locals())
    # check attrs
    if not isinstance(axis, int):
        raise TypeError("axis must be int type.")
    if not isinstance(indexes, list):
        raise TypeError("indexes must be list type.")
    if axis != 1 and axis != 2 and axis != 3:
        raise ValueError("axis must be 1, 2 or 3.")
    if len(indexes) == 0:
        raise ValueError("indexes can not be empty.")

    if name is None:
        out = helper.create_variable_for_type_inference(dtype=input.dtype)
    else:
        out = helper.create_variable(
            name=name, dtype=input.dtype, persistable=False)
    helper.append_op(
        type='similarity_focus',
        inputs={'X': input},
        outputs={'Out': out},
        attrs={"axis": axis,
               "indexes": indexes})
    return out
8047+
8048+
79368049
def hash(input, hash_size, num_hash=1, name=None):
79378050
"""
79388051
Hash the input to an integer whose value is less than the given hash size.

0 commit comments

Comments
 (0)