
Commit 568c4e5

recommit using account sneaxiy
1 parent 145aaa4 commit 568c4e5

File tree: 9 files changed, +483 −0 lines changed


paddle/fluid/operators/arg_max_op.cc

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_max_op.h"
/*
REGISTER_ARG_MINMAX_OP_WITHOUT_GRADIENT(arg_max, ArgMax);

REGISTER_ARG_MINMAX_KERNEL(arg_max, ArgMax, CPU);
*/

REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp,
                  paddle::operators::ArgMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    arg_max,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int64_t,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int32_t,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int16_t,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, uint8_t,
                                    int64_t>);
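
Note: every kernel registered above computes the same thing and differs only in the input element type; the output is always an int64 index tensor. As a rough, standalone illustration of the semantics (editor-added example, not part of the commit; the helper name argmax_rows is hypothetical, and the real kernel uses Eigen as shown in arg_min_max_op_base.h below):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: arg-max of each row of a row-major (rows x cols)
// matrix, i.e. arg_max with axis == 1. Indices are returned as int64_t,
// matching the output type registered for the kernels above.
std::vector<int64_t> argmax_rows(const std::vector<float>& data,
                                 int64_t rows, int64_t cols) {
  std::vector<int64_t> out(rows, 0);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 1; c < cols; ++c) {
      if (data[r * cols + c] > data[r * cols + out[r]]) out[r] = c;
    }
  }
  return out;
}

int main() {
  // X = [[3, 1, 2],
  //      [0, 5, 4]], axis = 1  ->  Out = [0, 1]
  std::vector<float> x = {3, 1, 2, 0, 5, 4};
  for (int64_t idx : argmax_rows(x, 2, 3)) std::cout << idx << " ";
  std::cout << "\n";  // prints: 0 1
}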

paddle/fluid/operators/arg_max_op.cu

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_max_op.h"

// REGISTER_ARG_MINMAX_KERNEL(arg_max, ArgMax, CUDA);

REGISTER_OP_CUDA_KERNEL(
    arg_max,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, double,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                    int64_t, int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                    int32_t, int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                    int16_t, int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, size_t,
                                    int64_t>,
    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                    uint8_t, int64_t>);

paddle/fluid/operators/arg_max_op.h

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/operators/arg_min_max_op_base.h"
paddle/fluid/operators/arg_min_max_op_base.h

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {

enum ArgMinMaxType { kArgMin, kArgMax };

template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
          ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)      \
  template <typename DeviceContext, typename T, typename Tout, int64_t Rank>  \
  struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                       \
                          enum_argminmax_value> {                             \
    void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
                    framework::LoDTensor& out, int64_t axis) {                \
      auto in_eigen = framework::EigenTensor<T, Rank>::From(in);              \
      auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(out);     \
      out_eigen.device(*(ctx.eigen_device())) =                               \
          in_eigen.eigen_op_type(axis).template cast<Tout>();                 \
    }                                                                         \
  }

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);

template <typename DeviceContext, typename T, typename Tout,
          ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& x = *(ctx.Input<framework::LoDTensor>("X"));
    auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
    out.mutable_data<Tout>(ctx.GetPlace());
    auto axis = ctx.Attr<int64_t>("axis");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

#define CALL_ARG_MINMAX_FUNCTOR(rank)                                \
  ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
      functor##rank;                                                 \
  functor##rank(dev_ctx, x, out, axis)

    switch (x.dims().size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;
      case 2:
        CALL_ARG_MINMAX_FUNCTOR(2);
        break;
      case 3:
        CALL_ARG_MINMAX_FUNCTOR(3);
        break;
      case 4:
        CALL_ARG_MINMAX_FUNCTOR(4);
        break;
      case 5:
        CALL_ARG_MINMAX_FUNCTOR(5);
        break;
      case 6:
        CALL_ARG_MINMAX_FUNCTOR(6);
        break;
      default:
        PADDLE_THROW(
            "%s operator doesn't support tensors whose ranks are greater "
            "than 6.",
            (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
        break;
    }
  }
};

template <typename DeviceContext, typename T, typename Tout>
using ArgMinKernel =
    ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMin>;

template <typename DeviceContext, typename T, typename Tout>
using ArgMaxKernel =
    ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMax>;

typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
    const auto& x_dims = ctx->GetInputDim("X");
    int64_t axis = ctx->Attrs().Get<int64_t>("axis");
    PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
                   "'axis' must be inside [-Rank(X), Rank(X))");

    auto x_rank = x_dims.size();
    if (axis < 0) axis += x_rank;

    std::vector<int64_t> vec;
    for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
    for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
    ctx->SetOutputDim("Out", framework::make_ddim(vec));
  }
} ArgMinOp, ArgMaxOp;

class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
 protected:
  virtual const char* OpName() const = 0;
  virtual const char* Name() const = 0;

 public:
  void Make() override {
    AddInput("X", "Input tensor.");
    AddOutput("Out", "Output tensor.");
    AddAttr<int64_t>("axis", "The axis in which to compute the arg indices.");
    AddComment(::paddle::string::Sprintf(R"DOC(
%s Operator.

Computes the indices of the %s elements of the input tensor along the provided axis.
)DOC",
                                         OpName(), Name()));
  }
};

class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMin"; }
  const char* Name() const override { return "min"; }
};

class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMax"; }
  const char* Name() const override { return "max"; }
};
}  // namespace operators
}  // namespace paddle
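
Note: to make the functor machinery above easier to follow, the kArgMax specialization generated by DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax) expands to roughly the following (hand-expanded by the editor, whitespace and the trailing semicolon adjusted; this is not an extra file in the commit):

template <typename DeviceContext, typename T, typename Tout, int64_t Rank>
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, ArgMinMaxType::kArgMax> {
  void operator()(const DeviceContext& ctx, const framework::LoDTensor& in,
                  framework::LoDTensor& out, int64_t axis) {
    // Map the input to a Rank-dimensional Eigen tensor and the output to a
    // (Rank - 1)-dimensional one, then evaluate Eigen's argmax along `axis`,
    // casting the resulting indices to Tout (int64_t in the registrations above).
    auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
    auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(out);
    out_eigen.device(*(ctx.eigen_device())) =
        in_eigen.argmax(axis).template cast<Tout>();
  }
};

BaseArgMinMaxOp::InferShape performs the matching rank reduction on the static shape: for example, an input X of shape [2, 3, 4] with axis = 1 (equivalently axis = -2) yields an Out of shape [2, 4].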

paddle/fluid/operators/arg_min_op.cc

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_op.h"
/*
REGISTER_ARG_MINMAX_OP_WITHOUT_GRADIENT(arg_min, ArgMin);

REGISTER_ARG_MINMAX_KERNEL(arg_min, ArgMin, CPU);
*/

REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp,
                  paddle::operators::ArgMinOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    arg_min,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int64_t,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int32_t,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int16_t,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, uint8_t,
                                    int64_t>);

paddle/fluid/operators/arg_min_op.cu

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_op.h"

// REGISTER_ARG_MINMAX_KERNEL(arg_min, ArgMin, CUDA);

REGISTER_OP_CUDA_KERNEL(
    arg_min,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, double,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                    int64_t, int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                    int32_t, int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                    int16_t, int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, size_t,
                                    int64_t>,
    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                    uint8_t, int64_t>);

paddle/fluid/operators/arg_min_op.h

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/operators/arg_min_max_op_base.h"
