Skip to content

Commit f3ffd75

Browse files
authored
add trace op and complex API of trace & sum, test=release/2.0 (#24197)
1 parent cf02888 commit f3ffd75

File tree

8 files changed

+881
-2
lines changed

8 files changed

+881
-2
lines changed

paddle/fluid/operators/trace_op.cc

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/trace_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class TraceOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext *ctx) const override {
25+
PADDLE_ENFORCE_EQ(
26+
ctx->HasInput("Input"), true,
27+
platform::errors::NotFound("Input of TraceOp is not found."));
28+
29+
PADDLE_ENFORCE_EQ(
30+
ctx->HasOutput("Out"), true,
31+
platform::errors::NotFound("Output of TraceOp is not found."));
32+
33+
int dim1 = ctx->Attrs().Get<int>("dim1");
34+
int dim2 = ctx->Attrs().Get<int>("dim2");
35+
36+
auto x_dims = ctx->GetInputDim("Input");
37+
38+
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 : dim1;
39+
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 : dim2;
40+
41+
PADDLE_ENFORCE_GE(
42+
x_dims.size(), 2,
43+
platform::errors::OutOfRange(
44+
"trace requires an tensor of at least two dimensions"));
45+
PADDLE_ENFORCE_LT(
46+
dim1_, x_dims.size(),
47+
platform::errors::OutOfRange(
48+
"Attr(dim1) is out of range (expected to be in range of [%ld, "
49+
"%ld], but got %ld).",
50+
-(x_dims.size()), (x_dims.size() - 1), dim1));
51+
PADDLE_ENFORCE_LT(
52+
dim2_, x_dims.size(),
53+
platform::errors::OutOfRange(
54+
"Attr(dim2) is out of range (expected to be in range of [%ld, "
55+
"%ld], but got %ld).",
56+
-(x_dims.size()), (x_dims.size() - 1), dim2));
57+
PADDLE_ENFORCE_NE(dim1_, dim2_,
58+
platform::errors::InvalidArgument(
59+
"The dimensions should not be identical "
60+
"%ld vs %ld.",
61+
dim1, dim2));
62+
63+
auto sizes = vectorize(x_dims);
64+
if (x_dims.size() == 2) {
65+
sizes.clear();
66+
sizes.push_back(1);
67+
} else {
68+
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
69+
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
70+
}
71+
ctx->SetOutputDim("Out", framework::make_ddim(sizes));
72+
}
73+
};
74+
75+
class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the op's inputs, outputs, attributes and user-facing docs.
  void Make() override {
    AddInput("Input",
             "(Tensor) The input tensor, from which the diagonals are taken.");
    AddOutput("Out", "(Tensor) the sum along diagonals of the input tensor");
    AddAttr<int>(
        "offset",
        R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
        )DOC")
        .SetDefault(0);
    // BUG FIX: the attribute docs previously claimed defaults of 0 and 1,
    // contradicting the actual SetDefault(-2) / SetDefault(-1) values.
    // The docs now state the real defaults.
    AddAttr<int>(
        "dim1",
        R"DOC((int, default -2), the first dim of the 2-D planes from which the diagonals should be taken.
        Can be both positive and negative. Default: -2.
        )DOC")
        .SetDefault(-2);
    AddAttr<int>(
        "dim2",
        R"DOC((int, default -1), the second dim of the 2-D planes from which the diagonals should be taken.
        Can be both positive and negative. Default: -1.
        )DOC")
        .SetDefault(-1);
    AddComment(R"DOC(
Trace Operator.
Return the sum along diagonals of the input tensor.
The behavior of this operator is similar to how `numpy.trace` works.

If Input is 2-D, returns the sum of diagonal.
If Input has larger dimensions, then returns a tensor of diagonals sum, diagonals be taken from
the 2-D planes specified by dim1 and dim2.

)DOC");
  }
};
110+
class TraceOpGrad : public framework::OperatorWithKernel {
111+
public:
112+
using framework::OperatorWithKernel::OperatorWithKernel;
113+
114+
void InferShape(framework::InferShapeContext *ctx) const override {
115+
PADDLE_ENFORCE_EQ(
116+
ctx->HasInput("Input"), true,
117+
platform::errors::NotFound("Input(Input) of TraceOp is not found."));
118+
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Input")), true,
119+
platform::errors::NotFound(
120+
"Output(Input@GRAD) of TraceGradOp is not found."));
121+
ctx->SetOutputDim(framework::GradVarName("Input"),
122+
ctx->GetInputDim("Input"));
123+
}
124+
125+
protected:
126+
framework::OpKernelType GetExpectedKernelType(
127+
const framework::ExecutionContext &ctx) const override {
128+
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
129+
ctx, framework::GradVarName("Out")),
130+
ctx.GetPlace());
131+
}
132+
};
133+
134+
template <typename T>
135+
class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
136+
public:
137+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
138+
139+
protected:
140+
void Apply(GradOpPtr<T> grad_op) const override {
141+
grad_op->SetType("trace_grad");
142+
grad_op->SetInput("Input", this->Input("Input"));
143+
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
144+
grad_op->SetOutput(framework::GradVarName("Input"),
145+
this->InputGrad("Input"));
146+
grad_op->SetAttrMap(this->Attrs());
147+
}
148+
};
149+
150+
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInference,
151+
"Input");
152+
153+
} // namespace operators
154+
} // namespace paddle
155+
156+
namespace ops = paddle::operators;
157+
REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
158+
ops::TraceGradOpMaker<paddle::framework::OpDesc>,
159+
ops::TraceGradOpMaker<paddle::imperative::OpBase>);
160+
161+
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
162+
ops::TraceGradNoNeedBufferVarsInference);
163+
REGISTER_OP_CPU_KERNEL(
164+
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
165+
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
166+
ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
167+
ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>);
168+
REGISTER_OP_CPU_KERNEL(
169+
trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
170+
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
171+
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
172+
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/trace_op.cu

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
16+
#include "paddle/fluid/operators/trace_op.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename T>
22+
struct IdentityFunctor {
23+
HOSTDEVICE explicit inline IdentityFunctor() {}
24+
25+
HOSTDEVICE inline T operator()(const T& x) const { return x; }
26+
};
27+
28+
template <typename DeviceContext, typename T>
29+
class TraceCUDAKernel : public framework::OpKernel<T> {
30+
public:
31+
void Compute(const framework::ExecutionContext& context) const override {
32+
auto* input = context.Input<framework::Tensor>("Input");
33+
auto* out = context.Output<framework::Tensor>("Out");
34+
35+
const int64_t offset = context.Attr<int>("offset");
36+
const int64_t dim1 = context.Attr<int>("dim1");
37+
const int64_t dim2 = context.Attr<int>("dim2");
38+
39+
T* out_data = out->mutable_data<T>(context.GetPlace());
40+
const framework::Tensor diag =
41+
Diagonal<DeviceContext, T>(context, input, offset, dim1, dim2);
42+
if (diag.numel() > 0) {
43+
auto stream = context.cuda_device_context().stream();
44+
std::vector<int> reduce_dims;
45+
reduce_dims.push_back(out->dims().size());
46+
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
47+
diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
48+
IdentityFunctor<T>(), stream);
49+
}
50+
}
51+
};
52+
} // namespace operators
53+
} // namespace paddle
54+
55+
namespace ops = paddle::operators;
56+
namespace platform = paddle::platform;
57+
REGISTER_OP_CUDA_KERNEL(
58+
trace, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int>,
59+
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
60+
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
61+
platform::float16>,
62+
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
63+
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>);
64+
REGISTER_OP_CUDA_KERNEL(
65+
trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
66+
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
67+
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
68+
platform::float16>,
69+
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
70+
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>);

0 commit comments

Comments
 (0)