Skip to content

Commit 9b43ede

Browse files
committed
Polish arg_min_max_op
* Remove unused arg_max/min_op.h * Remove reference parameter. Use pointer insteaded. * undef macro * Always set OutT as int64_t.
1 parent 6d32e96 commit 9b43ede

File tree

7 files changed

+60
-95
lines changed

7 files changed

+60
-95
lines changed

paddle/fluid/operators/arg_max_op.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/arg_max_op.h"
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1616

17-
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp,
17+
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
1818
paddle::operators::ArgMaxOpMaker,
1919
paddle::framework::EmptyGradOpMaker);
2020

2121
REGISTER_OP_CPU_KERNEL(
22-
arg_max, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
23-
float, int64_t>,
24-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double,
22+
arg_max,
23+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
24+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
25+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
2526
int64_t>,
26-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int64_t,
27-
int64_t>,
28-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int32_t,
29-
int64_t>,
30-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int16_t,
31-
int64_t>,
32-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t,
33-
int64_t>,
34-
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, uint8_t,
35-
int64_t>);
27+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
28+
int32_t>,
29+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
30+
int16_t>,
31+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
32+
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
33+
uint8_t>);

paddle/fluid/operators/arg_max_op.cu

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/arg_max_op.h"
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1616

1717
REGISTER_OP_CUDA_KERNEL(
1818
arg_max,
19-
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float,
20-
int64_t>,
21-
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, double,
19+
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
20+
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
21+
double>,
22+
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
2223
int64_t>,
2324
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
24-
int64_t, int64_t>,
25+
int32_t>,
2526
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
26-
int32_t, int64_t>,
27+
int16_t>,
2728
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
28-
int16_t, int64_t>,
29-
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, size_t,
30-
int64_t>,
29+
size_t>,
3130
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
32-
uint8_t, int64_t>);
31+
uint8_t>);

paddle/fluid/operators/arg_max_op.h

Lines changed: 0 additions & 16 deletions
This file was deleted.

paddle/fluid/operators/arg_min_max_op_base.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include <string>
1617
#include <type_traits>
1718
#include <vector>
1819
#include "paddle/fluid/framework/ddim.h"
@@ -37,9 +38,9 @@ struct ArgMinMaxFunctor {};
3738
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
3839
enum_argminmax_value> { \
3940
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
40-
framework::LoDTensor& out, int64_t axis) { \
41+
framework::LoDTensor* out, int64_t axis) { \
4142
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
42-
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(out); \
43+
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
4344
out_eigen.device(*(ctx.eigen_device())) = \
4445
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
4546
} \
@@ -62,7 +63,7 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
6263
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
6364
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
6465
functor##rank; \
65-
functor##rank(dev_ctx, x, out, axis)
66+
functor##rank(dev_ctx, x, &out, axis)
6667

6768
switch (x.dims().size()) {
6869
case 1:
@@ -89,19 +90,20 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
8990
"than 6.",
9091
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
9192
break;
93+
#undef CALL_ARG_MINMAX_FUNCTOR
9294
}
9395
}
9496
};
9597

96-
template <typename DeviceContext, typename T, typename Tout>
98+
template <typename DeviceContext, typename T>
9799
using ArgMinKernel =
98-
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMin>;
100+
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;
99101

100-
template <typename DeviceContext, typename T, typename Tout>
102+
template <typename DeviceContext, typename T>
101103
using ArgMaxKernel =
102-
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMax>;
104+
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
103105

104-
typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
106+
class ArgMinMaxOp : public framework::OperatorWithKernel {
105107
public:
106108
using framework::OperatorWithKernel::OperatorWithKernel;
107109

@@ -121,7 +123,7 @@ typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
121123
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
122124
ctx->SetOutputDim("Out", framework::make_ddim(vec));
123125
}
124-
} ArgMinOp, ArgMaxOp;
126+
};
125127

126128
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
127129
protected:
@@ -133,12 +135,13 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
133135
AddInput("X", "Input tensor.");
134136
AddOutput("Out", "Output tensor.");
135137
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
136-
AddComment(::paddle::string::Sprintf(R"DOC(
137-
%s Operator.
138+
AddComment(string::Sprintf(R"DOC(
139+
%s Operator.
138140
139-
Computes the indices of the %s elements of the input tensor's element along the provided axis.
141+
Computes the indices of the %s elements of the input tensor's element
142+
along the provided axis.
140143
)DOC",
141-
OpName(), Name()));
144+
OpName(), Name()));
142145
}
143146
};
144147

paddle/fluid/operators/arg_min_op.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/arg_min_op.h"
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1616

17-
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp,
17+
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
1818
paddle::operators::ArgMinOpMaker,
1919
paddle::framework::EmptyGradOpMaker);
2020

2121
REGISTER_OP_CPU_KERNEL(
22-
arg_min, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
23-
float, int64_t>,
24-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double,
22+
arg_min,
23+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
24+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
25+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
2526
int64_t>,
26-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int64_t,
27-
int64_t>,
28-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int32_t,
29-
int64_t>,
30-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int16_t,
31-
int64_t>,
32-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t,
33-
int64_t>,
34-
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, uint8_t,
35-
int64_t>);
27+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
28+
int32_t>,
29+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
30+
int16_t>,
31+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
32+
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
33+
uint8_t>);

paddle/fluid/operators/arg_min_op.cu

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/arg_min_op.h"
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"
1616

1717
REGISTER_OP_CUDA_KERNEL(
1818
arg_min,
19-
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float,
20-
int64_t>,
21-
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, double,
19+
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
20+
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
21+
double>,
22+
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
2223
int64_t>,
2324
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
24-
int64_t, int64_t>,
25+
int32_t>,
2526
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
26-
int32_t, int64_t>,
27+
int16_t>,
2728
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
28-
int16_t, int64_t>,
29-
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, size_t,
30-
int64_t>,
29+
size_t>,
3130
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
32-
uint8_t, int64_t>);
31+
uint8_t>);

paddle/fluid/operators/arg_min_op.h

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)