
Commit 3d939d3

Merge pull request #16023 from heavengate/kl_div_loss
KL div loss: add kldiv_loss op
2 parents 5447463 + e56fd43 commit 3d939d3

File tree

7 files changed: +437 −0 lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func',
 paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d'))
 paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
 paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7'))
+paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec'))
 paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
 paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
 paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle/fluid/operators/kldiv_loss_op.cc

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/kldiv_loss_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class KLDivLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of KLDivLossOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Target"),
                   "Input(Target) of KLDivLossOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of KLDivLossOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");
    auto dim_target = ctx->GetInputDim("Target");
    PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
                      "Input(X) rank and Input(Target) rank should be same.");
    for (int i = 0; i < dim_x.size(); i++) {
      PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
                        "Input(X) and Input(Target) should be in same shape.");
    }

    auto reduction = ctx->Attrs().Get<std::string>("reduction");

    PADDLE_ENFORCE(
        "mean" == reduction || "sum" == reduction || "batchmean" == reduction ||
            "none" == reduction,
        "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'.");

    if ("none" == reduction) {
      ctx->SetOutputDim("Loss", dim_x);
    } else {
      ctx->SetOutputDim("Loss", {1});
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of KL divergence loss operator. "
             "This is a tensor with shape of [N, *], where N is the "
             "batch size, * means any number of additional dimensions.");
    AddInput("Target",
             "The target tensor of KL divergence loss operator. "
             "This is a tensor with the same shape as Input(X).");
    AddOutput(
        "Loss",
        "The output KL divergence loss tensor. If Attr(reduction) is "
        "'none', this tensor should be in the same shape as Input(X), else "
        "this tensor should be in shape of [1].");

    AddAttr<std::string>(
        "reduction",
        "The reduction type to apply to the output, available types "
        "are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no "
        "reduction, 'batchmean' for the sum of output divided by "
        "batch size, 'mean' for the average value of all output, "
        "'sum' for the sum of the output.")
        .SetDefault("mean");

    AddComment(R"DOC(
This operator calculates the Kullback-Leibler divergence loss
between Input(X) and Input(Target).

KL divergence loss is calculated as follows:

$$l(x, y) = y * (\log(y) - x)$$

where :math:`x` is Input(X) and :math:`y` is Input(Target).

When :attr:`reduction` is :attr:`none`, output loss is in
the same shape as Input(X), loss in each point is calculated
separately and no reduction is applied.

When :attr:`reduction` is :attr:`mean`, output loss is in
shape of [1] and loss value is the mean value of all losses.

When :attr:`reduction` is :attr:`sum`, output loss is in
shape of [1] and loss value is the sum value of all losses.

When :attr:`reduction` is :attr:`batchmean`, output loss is
in shape of [1] and loss value is the sum value of all losses
divided by batch size.

)DOC");
  }
};

class KLDivLossOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(Target) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class KLDivLossOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("kldiv_loss_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Target", Input("Target"));
    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));

    op->SetAttrMap(Attrs());

    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
                  ops::KLDivLossOpGradMaker);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad);
REGISTER_OP_CPU_KERNEL(
    kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
    ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    kldiv_loss_grad,
    ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, double>);
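
As a reading aid for the reduction semantics registered above, a minimal NumPy sketch of what the forward pass computes. This is a hypothetical reference model, not part of the PR; the name kldiv_loss_ref is ours:

import numpy as np

def kldiv_loss_ref(x, target, reduction='mean'):
    # x holds log-probabilities, target holds probabilities.
    # Elementwise term: target * (log(target) - x), zero where target <= 0,
    # matching the target <= 0 guard in KLDivLossForward.
    safe_t = np.maximum(target, 1e-38)  # avoid log(0) warnings; masked below
    loss = np.where(target > 0, target * (np.log(safe_t) - x), 0.0)
    if reduction == 'none':
        return loss                     # same shape as x
    elif reduction == 'batchmean':
        return loss.sum() / x.shape[0]  # total loss divided by batch size N
    elif reduction == 'mean':
        return loss.mean()              # average over all elements
    else:                               # 'sum'
        return loss.sum()

Every reduced mode returns a single value, matching the Loss shape of [1] set by InferShape; only 'none' keeps the input shape.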
paddle/fluid/operators/kldiv_loss_op.cu

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    kldiv_loss,
    ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, float>,
    ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    kldiv_loss_grad,
    ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, double>);
paddle/fluid/operators/kldiv_loss_op.h

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

using Array1 = Eigen::DSizes<int64_t, 1>;

template <typename T>
struct KLDivLossForward {
  HOSTDEVICE KLDivLossForward() {}

  HOSTDEVICE T operator()(const T& target, const T& input) const {
    if (target <= 0) {
      return 0;
    } else {
      return target * (std::log(target) - input);
    }
  }
};
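
// Editor's note: the forward rule is l(x, y) = y * (log(y) - x) for y > 0.
// The y <= 0 branch encodes the convention 0 * log(0) = 0, so
// zero-probability targets contribute nothing to the loss.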

template <typename T>
struct KLDivLossBackward {
  HOSTDEVICE KLDivLossBackward() {}

  HOSTDEVICE T operator()(const T& target, const T& grad) const {
    if (target <= 0) {
      return 0;
    } else {
      return static_cast<T>(-1.) * grad;
    }
  }
};
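
// Editor's note: since dl/dx = d[y * (log(y) - x)]/dx = -y for y > 0, and the
// kernel below pre-multiplies the upstream gradient by target, returning
// -grad here yields -target * dLoss, i.e. the chain rule applied to dl/dx.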

template <typename DeviceContext, typename T>
class KLDivLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto* input = ctx.Input<Tensor>("X");
    auto* target = ctx.Input<Tensor>("Target");
    auto* loss = ctx.Output<Tensor>("Loss");
    auto reduction = ctx.Attr<std::string>("reduction");

    const int n = input->dims()[0];

    loss->mutable_data<T>(ctx.GetPlace());
    auto input_t = EigenVector<T>::Flatten(*input);
    auto target_t = EigenVector<T>::Flatten(*target);
    auto loss_t = EigenVector<T>::Flatten(*loss);
    auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
    if ("none" == reduction) {
      loss_t.device(place) = output;
    } else if ("batchmean" == reduction) {
      auto output_sum = output.sum().eval();
      loss_t.device(place) = output_sum / output_sum.constant(n);
    } else if ("mean" == reduction) {
      loss_t.device(place) = output.mean();
    } else if ("sum" == reduction) {
      loss_t.device(place) = output.sum();
    }
  }
};
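
// Editor's note: the gradient kernel below broadcasts Loss@GRAD (a single
// value for any reduced mode) back to the input shape, applies dl/dx =
// -target where target > 0, then rescales by 1/numel ('mean') or 1/N
// ('batchmean') to mirror the forward reductions.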

template <typename DeviceContext, typename T>
class KLDivLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto* target = ctx.Input<Tensor>("Target");
    auto reduction = ctx.Attr<std::string>("reduction");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));

    const int n = input_grad->dims()[0];
    const int numel = input_grad->numel();
    const int expand = numel / loss_grad->numel();

    input_grad->mutable_data<T>(ctx.GetPlace());

    auto target_t = EigenVector<T>::Flatten(*target);

    auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
    auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);

    auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
    auto grad_t = target_t * loss_grad_expand;
    input_grad_t.device(place) =
        target_t.binaryExpr(grad_t, KLDivLossBackward<T>());

    if ("mean" == reduction) {
      input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
    } else if ("batchmean" == reduction) {
      input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
    }
  }
};

}  // namespace operators
}  // namespace paddle
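
The matching backward reference in NumPy, under the same assumptions as the forward sketch above (again ours, not the shipped kernel):

import numpy as np

def kldiv_loss_grad_ref(x, target, loss_grad, reduction='mean'):
    # Broadcast the upstream gradient to the input shape; for any reduced
    # mode loss_grad is a single value, for 'none' it already matches x.
    g = np.broadcast_to(loss_grad, x.shape).astype(x.dtype)
    # dl/dx = -target where target > 0, else 0 (KLDivLossBackward).
    dx = np.where(target > 0, -target * g, 0.0)
    if reduction == 'mean':
        dx = dx / x.size       # 'mean' rescales by total element count
    elif reduction == 'batchmean':
        dx = dx / x.shape[0]   # 'batchmean' rescales by batch size N
    return dx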

python/paddle/fluid/layers/nn.py

Lines changed: 33 additions & 0 deletions
@@ -188,6 +188,7 @@
     'psroi_pool',
     'teacher_student_sigmoid_loss',
     'huber_loss',
+    'kldiv_loss',
     'tree_conv',
     'npair_loss',
     'fsp_matrix',
@@ -10762,6 +10763,38 @@ def huber_loss(input, label, delta):
     return out


+@templatedoc()
+def kldiv_loss(x, target, reduction='mean', name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Variable): ${x_comment}
+        target (Variable): ${target_comment}
+        reduction (str): ${reduction_comment}
+        name (str, default None): The name of this layer.
+
+    Returns:
+        kldiv\_loss (Variable): The KL divergence loss.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[4,2,2], dtype='float32')
+            target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
+            loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')
+    """
+    helper = LayerHelper('kldiv_loss', **locals())
+    loss = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='kldiv_loss',
+        inputs={'X': x,
+                'Target': target},
+        outputs={'Loss': loss},
+        attrs={'reduction': reduction})
+    return loss
+
+
 @templatedoc()
 def tree_conv(nodes_vector,
               edge_set,
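
For readers trying the new layer, a fuller end-to-end sketch using the fluid 1.x executor workflow of this era; the feed shapes and random data are our own assumptions, not part of the PR:

import numpy as np
import paddle.fluid as fluid

# Build the graph: x is expected to hold log-probabilities,
# target holds probabilities of the same shape.
x = fluid.layers.data(name='x', shape=[4, 2, 2], dtype='float32')
target = fluid.layers.data(name='target', shape=[4, 2, 2], dtype='float32')
loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Feed a batch of 2 samples; data layers prepend the batch dimension.
x_np = np.random.uniform(-10, 10, (2, 4, 2, 2)).astype('float32')
t_np = np.random.uniform(0, 1, (2, 4, 2, 2)).astype('float32')
out, = exe.run(fluid.default_main_program(),
               feed={'x': x_np, 'target': t_np},
               fetch_list=[loss])
print(out.shape)  # (1,) for any reduced mode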
