Commit 1f2b8c1

add argsort op and delete unnecessary file, test=develop (#5740)

fix linspace and argsort bugs, test=develop; fix argsort and add 2-rank input reduce_max && reduce_min

1 parent d73b69b, commit 1f2b8c1

14 files changed, +461 -3 lines

lite/backends/arm/math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
       beam_search.cc
       reduce_max.cc
       reduce_min.cc
+      reduce_max_min.cc
       sequence_pool.cc
       sequence_pool_grad.cc
       sequence_expand.cc

lite/backends/arm/math/funcs.h

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@
 #include "lite/backends/arm/math/prior_box.h"
 #include "lite/backends/arm/math/quantize.h"
 #include "lite/backends/arm/math/reduce_max.h"
+#include "lite/backends/arm/math/reduce_max_min.h"
 #include "lite/backends/arm/math/reduce_mean.h"
 #include "lite/backends/arm/math/reduce_min.h"
 #include "lite/backends/arm/math/reduce_prod.h"

lite/backends/arm/math/reduce_max_min.cc

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/backends/arm/math/reduce_max_min.h"
+#include <utility>
+#include <vector>
+#include "lite/backends/arm/math/funcs.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <>
+void reduce_second_of_two<float>(const float* src,
+                                 float* dst,
+                                 int first_in,
+                                 int second_in,
+                                 MaxMinType max_min_selector) {
+  // max_min_selector == true, do reduce max; else do reduce min
+  for (int j = 0; j < second_in; j++) {
+    dst[j * first_in] = src[j * first_in];
+    for (int k = 1; k < first_in; k++) {
+      dst[j * first_in] = (src[j * first_in + k] <= dst[j * first_in]) ^
+                                  static_cast<bool>(max_min_selector)
+                              ? src[j * first_in + k]
+                              : dst[j * first_in];
+    }
+  }
+}
+
+template <>
+void reduce_first_of_two<float>(const float* src,
+                                float* dst,
+                                int first_in,
+                                int second_in,
+                                MaxMinType max_min_selector) {
+  // max_min_selector == true, do reduce max; else do reduce min
+  for (int j = 0; j < first_in; j++) {
+    dst[j] = src[j];
+    for (int k = 1; k < second_in; k++) {
+      dst[j] = (src[j + k * first_in] <= dst[j]) ^
+                       static_cast<bool>(max_min_selector)
+                   ? src[j + k * first_in]
+                   : dst[j];
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
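
Note on the loop body above: a single XOR against max_min_selector turns one comparison into either a running max (kMax, i.e. true) or a running min (kMin, i.e. false). The following is a minimal standalone sketch, not part of the commit, that isolates this trick and checks it against std::max_element / std::min_element; every name in it is local to the sketch.

// Standalone illustration of the XOR-based max/min selection used in
// reduce_max_min.cc above. Not part of the commit.
#include <algorithm>
#include <cassert>
#include <vector>

enum class MaxMinType : bool { kMin = false, kMax = true };

float reduce_one(const float* src, int n, MaxMinType sel) {
  float best = src[0];
  for (int k = 1; k < n; ++k) {
    // sel == kMax: (src[k] <= best) ^ true  -> take src[k] only when it is larger.
    // sel == kMin: (src[k] <= best) ^ false -> take src[k] only when it is <= best.
    best = ((src[k] <= best) ^ static_cast<bool>(sel)) ? src[k] : best;
  }
  return best;
}

int main() {
  std::vector<float> v = {3.f, -1.f, 7.f, 7.f, 0.5f};
  assert(reduce_one(v.data(), static_cast<int>(v.size()), MaxMinType::kMax) ==
         *std::max_element(v.begin(), v.end()));
  assert(reduce_one(v.data(), static_cast<int>(v.size()), MaxMinType::kMin) ==
         *std::min_element(v.begin(), v.end()));
  return 0;
}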

lite/backends/arm/math/reduce_max_min.h

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+enum class MaxMinType : bool { kMin = false, kMax = true };
+template <typename DataType>
+void reduce_first_of_two(const float* src,
+                         float* dst,
+                         int first_in,
+                         int second_in,
+                         MaxMinType compare_functor);
+
+template <typename DataType>
+void reduce_second_of_two(const float* src,
+                          float* dst,
+                          int first_in,
+                          int second_in,
+                          MaxMinType max_min_selector);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle

lite/kernels/arm/reduce_max_compute.cc

Lines changed: 31 additions & 2 deletions
@@ -13,7 +13,9 @@
 // limitations under the License.

 #include "lite/kernels/arm/reduce_max_compute.h"
+
 #include <string>
+
 #include "lite/backends/arm/math/funcs.h"

 namespace paddle {
@@ -104,9 +106,36 @@ void ReduceMaxCompute::Run() {
     } else {
       LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
     }
+  } else if (x_dims.size() == 2) {
+    int first_in = x_dims[0];
+    int second_in = x_dims[1];
+    if (dim.size() == 1) {
+      switch (dim[0]) {
+        case 0:
+          lite::arm::math::reduce_first_of_two<float>(
+              input,
+              output,
+              first_in,
+              second_in,
+              lite::arm::math::MaxMinType::kMax);
+          break;
+        case 1:
+          lite::arm::math::reduce_second_of_two<float>(
+              input,
+              output,
+              first_in,
+              second_in,
+              lite::arm::math::MaxMinType::kMax);
+          break;
+        default:
+          LOG(FATAL) << "error!!!";
+      }
+    } else {
+      LOG(FATAL) << "dim's size over than 1, which is not supported now!!";
+    }  // x_dims == 2 && dim.size() == 1
   } else {
-    LOG(FATAL) << "only support input with 3&4 dimensions now!!";
-  }
+    LOG(FATAL) << "only support input with 2&3&4 dimensions now!!";
+  }  // x_dims == 2
 }

 }  // namespace arm
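
For a 2-D input of shape [first_in, second_in], the new branch dispatches dim = {0} to reduce_first_of_two and dim = {1} to reduce_second_of_two, both with MaxMinType::kMax. The sketch below shows the intended axis-wise semantics on a row-major matrix; it is a plain reference written for this note (the helper names are hypothetical), not the ARM routines themselves.

// Reference behaviour of reduce_max over one axis of a row-major
// [rows, cols] matrix; written for illustration only.
#include <algorithm>
#include <cstdio>
#include <vector>

// dim == 0: collapse rows, one max per column.
std::vector<float> reduce_max_dim0(const std::vector<float>& x, int rows, int cols) {
  std::vector<float> out(cols);
  for (int c = 0; c < cols; ++c) {
    out[c] = x[c];
    for (int r = 1; r < rows; ++r) out[c] = std::max(out[c], x[r * cols + c]);
  }
  return out;
}

// dim == 1: collapse columns, one max per row.
std::vector<float> reduce_max_dim1(const std::vector<float>& x, int rows, int cols) {
  std::vector<float> out(rows);
  for (int r = 0; r < rows; ++r) {
    out[r] = x[r * cols];
    for (int c = 1; c < cols; ++c) out[r] = std::max(out[r], x[r * cols + c]);
  }
  return out;
}

int main() {
  // x = [[1, 5, 2],
  //      [4, 0, 6]]
  std::vector<float> x = {1, 5, 2, 4, 0, 6};
  auto d0 = reduce_max_dim0(x, 2, 3);  // expects {4, 5, 6}
  auto d1 = reduce_max_dim1(x, 2, 3);  // expects {5, 6}
  for (float v : d0) std::printf("%g ", v);
  std::printf("| ");
  for (float v : d1) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}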

lite/kernels/arm/reduce_min_compute.cc

Lines changed: 25 additions & 0 deletions
@@ -104,6 +104,31 @@ void ReduceMinCompute::Run() {
     } else {
       LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
     }
+  } else if (x_dims.size() == 2) {
+    int first_in = x_dims[0];
+    int second_in = x_dims[1];
+    if (dim.size() == 1) {
+      switch (dim[0]) {
+        case 0:
+          lite::arm::math::reduce_first_of_two<float>(
+              input,
+              output,
+              first_in,
+              second_in,
+              lite::arm::math::MaxMinType::kMin);
+          break;
+        case 1:
+          lite::arm::math::reduce_second_of_two<float>(
+              input,
+              output,
+              first_in,
+              second_in,
+              lite::arm::math::MaxMinType::kMin);
+          break;
+        default:
+          LOG(FATAL) << "error!!!";
+      }
+    }
   } else {
     LOG(FATAL) << "only support input with 3&4 dimensions now!!";
   }
lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ add_kernel(fill_any_like_compute_host Host extra SRCS fill_any_like_compute.cc D
 add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(linspace_compute_host Host extra SRCS linspace_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(tril_triu_compute_host Host extra SRCS tril_triu_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(argsort Host extra SRCS argsort_compute.cc DEPS ${lite_kernel_deps})

 if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
   lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)

lite/kernels/host/argsort_compute.cc

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/host/argsort_compute.h"
+
+using argsort_fp32_compute = paddle::lite::kernels::host::ArgsortCompute<float>;
+REGISTER_LITE_KERNEL(
+    argsort, kHost, kFloat, kAny, argsort_fp32_compute, argsort_fp32)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Indices",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+
+using argsort_int32_compute =
+    paddle::lite::kernels::host::ArgsortCompute<int32_t>;
+REGISTER_LITE_KERNEL(
+    argsort, kHost, kFloat, kAny, argsort_int32_compute, argsort_int32)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Indices",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt32),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+
+using argsort_int64_compute =
+    paddle::lite::kernels::host::ArgsortCompute<int64_t>;
+REGISTER_LITE_KERNEL(
+    argsort, kHost, kFloat, kAny, argsort_int64_compute, argsort_int64)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Indices",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kAny))})
+    .Finalize();

lite/kernels/host/argsort_compute.h

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <typename DataType>
+class ArgsortCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny)> {
+ public:
+  using param_t = operators::ArgsortParam;
+
+  void Run() {
+    auto& param = Param<operators::ArgsortParam>();
+    const DataType* x_data = param.X->template data<DataType>();
+    DataType* out_val = param.Out->template mutable_data<DataType>();
+    auto out_ind = param.Indices->template mutable_data<int64_t>();
+    DDim x_dims = param.X->dims();
+    int axis = param.axis;
+    int dim_size = x_dims.size();
+    bool descending = param.descending;
+    if (axis < 0) {
+      axis += dim_size;
+    }
+
+    int outer_size = x_dims.count(0, axis);
+    int axis_size = x_dims[axis];
+    int inner_size = x_dims.count(axis + 1, dim_size);
+    int sort_size = axis_size * inner_size;
+#pragma omp parallel for
+    for (int n = 0; n < outer_size; n++) {
+      const DataType* in_data = x_data + n * sort_size;
+      DataType* out_data = out_val + n * sort_size;
+      int64_t* out_ind_data = out_ind + n * sort_size;
+      for (int i = 0; i < inner_size; i++) {
+        std::vector<std::pair<DataType, int>> vec;
+        vec.resize(axis_size);
+        for (int j = 0; j < axis_size; j++) {
+          vec[j] = std::make_pair(in_data[j * inner_size + i], j);
+        }
+        if (descending) {
+          std::sort(vec.begin(),
+                    vec.end(),
+                    [](std::pair<DataType, int> a, std::pair<DataType, int> b) {
+                      return a.first > b.first;
+                    });
+        } else {
+          std::sort(vec.begin(),
+                    vec.end(),
+                    [](std::pair<DataType, int> a, std::pair<DataType, int> b) {
+                      return a.first < b.first;
+                    });
+        }
+        for (int j = 0; j < axis_size; j++) {
+          out_data[j * inner_size + i] = vec[j].first;
+          out_ind_data[j * inner_size + i] = vec[j].second;
+        }
+      }
+    }
+  }
+
+  virtual ~ArgsortCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
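
The kernel above flattens X into outer_size x axis_size x inner_size blocks, sorts each length-axis_size column of (value, original index) pairs along the chosen axis, and writes the sorted values to Out and the original positions to Indices (always int64). The following is a minimal standalone sketch of the same decomposition with the Lite tensor machinery replaced by raw vectors; the argsort function below is written for this note and is not the kernel's API.

// Standalone sketch of axis-wise argsort using the same
// outer/axis/inner index decomposition as ArgsortCompute::Run().
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

void argsort(const std::vector<float>& x, const std::vector<int>& dims, int axis,
             bool descending, std::vector<float>* out, std::vector<int64_t>* indices) {
  int outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) inner *= dims[i];
  const int axis_size = dims[axis];
  out->resize(x.size());
  indices->resize(x.size());
  for (int n = 0; n < outer; ++n) {
    for (int i = 0; i < inner; ++i) {
      // Gather one column along the sort axis together with its source positions.
      std::vector<std::pair<float, int>> col(axis_size);
      for (int j = 0; j < axis_size; ++j)
        col[j] = {x[(n * axis_size + j) * inner + i], j};
      std::sort(col.begin(), col.end(),
                [descending](const std::pair<float, int>& a,
                             const std::pair<float, int>& b) {
                  return descending ? a.first > b.first : a.first < b.first;
                });
      for (int j = 0; j < axis_size; ++j) {
        (*out)[(n * axis_size + j) * inner + i] = col[j].first;
        (*indices)[(n * axis_size + j) * inner + i] = col[j].second;
      }
    }
  }
}

int main() {
  // 2x3 input, axis = 1, ascending:
  // [[3, 1, 2],   Out  [[1, 2, 3],   Indices [[1, 2, 0],
  //  [0, 5, 4]]         [0, 4, 5]]            [0, 2, 1]]
  std::vector<float> x = {3, 1, 2, 0, 5, 4};
  std::vector<float> vals;
  std::vector<int64_t> idx;
  argsort(x, {2, 3}, 1, false, &vals, &idx);
  for (size_t k = 0; k < vals.size(); ++k)
    std::printf("%g/%lld ", vals[k], static_cast<long long>(idx[k]));
  std::printf("\n");
  return 0;
}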

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ add_operator(tile_op extra SRCS tile_op.cc DEPS ${op_DEPS})
 add_operator(meshgrid_op_lite extra SRCS meshgrid_op.cc DEPS ${op_DEPS})
 add_operator(linspace_op extra SRCS linspace_op.cc DEPS ${op_DEPS})
 add_operator(tril_triu_op extra SRCS tril_triu_op.cc DEPS ${op_DEPS})
+add_operator(argsort_op extra SRCS argsort_op.cc DEPS ${op_DEPS})

 # for OCR specific
 add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
