
Commit 1fe5acb

Expose sigmoid_cross_entropy_with_logits (#6147)

Also changes the `labels` input name to `label` for API consistency.

1 parent 44e3914
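
For orientation, the layer exposed here is exercised by the new test in test_layers.py further down; the snippet below is a minimal usage sketch based on that test (the import paths are assumed from the fluid module layout of this era, and the variable names are illustrative):

import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program, program_guard  # assumed import path

# Wire logits and labels of identical shape/dtype into the layer;
# the op produces an elementwise logistic loss of the same shape.
program = Program()
with program_guard(program):
    x = layers.data(name='data', shape=[10], dtype='float32')
    label = layers.data(name='label', shape=[10], dtype='float32')
    loss = layers.sigmoid_cross_entropy_with_logits(x=x, label=label)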

5 files changed: +41 -29 lines

paddle/operators/sigmoid_cross_entropy_with_logits_op.cc

Lines changed: 11 additions & 13 deletions

@@ -25,20 +25,19 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Labels"),
-                   "Input(Labels) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
-    auto labels_dims = ctx->GetInputDim("Labels");
+    auto labels_dims = ctx->GetInputDim("Label");
     PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
-                      "Input(Labels)'s rank should be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
-                      "The 1st dimension of Input(X) and Input(Labels) should "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
-                      "The 2nd dimension of Input(X) and Input(Labels) should "
+                      "The 2nd dimension of Input(X) and Input(Label) should "
                       "be equal.");
 
     ctx->SetOutputDim("Out", x_dims);
@@ -53,26 +52,25 @@ class SigmoidCrossEntropyWithLogitsGradOp
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Labels"),
-                   "Input(Labels) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) shoudl be not null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@GRAD) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
-    auto labels_dims = ctx->GetInputDim("Labels");
+    auto labels_dims = ctx->GetInputDim("Label");
     auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
-                      "Input(Labels)'s rank should be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(dout_dims.size(), 2,
                       "Input(Out@Grad)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
-                      "The 1st dimension of Input(X) and Input(Labels) should "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
-                      "The 2nd dimension of Input(X) and Input(Labels) should "
+                      "The 2nd dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0],
                       "The 1st dimension of Input(X) and Input(Out@Grad) "
@@ -97,7 +95,7 @@ class SigmoidCrossEntropyWithLogitsOpMaker
              "This input is a tensor of logits computed by the previous "
              " operator. Logits are unscaled log probabilities given as "
              "log(p/(1-p)).");
-    AddInput("Labels",
+    AddInput("Label",
             "(Tensor, default Tensor<float>), a 2-D tensor of the same type "
             "and shape as X. This input is a tensor of probabalistic labels "
             "for each logit");

paddle/operators/sigmoid_cross_entropy_with_logits_op.h

Lines changed: 2 additions & 4 deletions

@@ -25,8 +25,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    const framework::Tensor *Labels =
-        context.Input<framework::Tensor>("Labels");
+    const framework::Tensor *Labels = context.Input<framework::Tensor>("Label");
     framework::Tensor *Out = context.Output<framework::Tensor>("Out");
     Out->mutable_data<T>(context.GetPlace());
 
@@ -52,8 +51,7 @@ class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    const framework::Tensor *Labels =
-        context.Input<framework::Tensor>("Labels");
+    const framework::Tensor *Labels = context.Input<framework::Tensor>("Label");
     const framework::Tensor *dOut =
         context.Input<framework::Tensor>(framework::GradVarName("Out"));
     framework::Tensor *dX =
python/paddle/v2/fluid/layers.py

Lines changed: 1 addition & 0 deletions

@@ -403,6 +403,7 @@ def func(**kwargs):
 _create_op_func_('scale')
 _create_op_func_('reshape')
 _create_op_func_('transpose')
+_create_op_func_('sigmoid_cross_entropy_with_logits')
 
 
 def cast(x, dtype, main_program=None):
python/paddle/v2/fluid/tests/test_layers.py

Lines changed: 10 additions & 0 deletions

@@ -137,6 +137,16 @@ def test_linear_chain_crf(self):
 
         print(str(program))
 
+    def test_sigmoid_cross_entropy(self):
+        program = Program()
+        with program_guard(program):
+            dat = layers.data(name='data', shape=[10], dtype='float32')
+            lbl = layers.data(name='label', shape=[10], dtype='float32')
+            self.assertIsNotNone(
+                layers.sigmoid_cross_entropy_with_logits(
+                    x=dat, label=lbl))
+        print(str(program))
+
 
 if __name__ == '__main__':
     unittest.main()

python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py

Lines changed: 17 additions & 12 deletions

@@ -2,11 +2,12 @@
 from op_test import OpTest
 from scipy.special import logit
 from scipy.special import expit
+import unittest
 
 
 class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
-    '''Test sigmoid_cross_entropy_with_logit_op with binary labels
-    '''
+    """Test sigmoid_cross_entropy_with_logit_op with binary label
+    """
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
@@ -16,16 +17,16 @@ def setUp(self):
             'X': logit(
                 np.random.uniform(0, 1, (batch_size, num_classes))
                 .astype("float32")),
-            'Labels': np.random.randint(0, 2, (batch_size, num_classes))
+            'Label': np.random.randint(0, 2, (batch_size, num_classes))
             .astype("float32")
         }
 
         # Fw Pass is implemented as elementwise sigmoid followed by
         # elementwise logistic loss
-        # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X))
+        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
         sigmoid_X = expit(self.inputs['X'])
-        term1 = self.inputs['Labels'] * np.log(sigmoid_X)
-        term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X)
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
         self.outputs = {'Out': -term1 - term2}
 
     def test_check_output(self):
@@ -36,8 +37,8 @@ def test_check_grad(self):
 
 
 class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
-    '''Test sigmoid_cross_entropy_with_logit_op with probabalistic labels
-    '''
+    """Test sigmoid_cross_entropy_with_logit_op with probabalistic label
+    """
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
@@ -47,20 +48,24 @@ def setUp(self):
             'X': logit(
                 np.random.uniform(0, 1, (batch_size, num_classes))
                 .astype("float32")),
-            'Labels': np.random.uniform(0, 1, (batch_size, num_classes))
+            'Label': np.random.uniform(0, 1, (batch_size, num_classes))
            .astype("float32")
         }
 
         # Fw Pass is implemented as elementwise sigmoid followed by
         # elementwise logistic loss
-        # Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X))
+        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
         sigmoid_X = expit(self.inputs['X'])
-        term1 = self.inputs['Labels'] * np.log(sigmoid_X)
-        term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X)
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
         self.outputs = {'Out': -term1 - term2}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
         self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
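
The two tests pin down the forward formula in NumPy. For completeness, differentiating that elementwise loss gives dLoss/dX = sigmoid(X) - Label, which is what the grad kernel must produce, scaled by Out@GRAD. A standalone sketch of both passes with a finite-difference check (function names here are illustrative, not part of the library):

import numpy as np
from scipy.special import expit  # numerically stable sigmoid

def sigmoid_xent_fw(x, label):
    # Reference forward pass, mirroring the tests above:
    # label * -log(sigmoid(x)) + (1 - label) * -log(1 - sigmoid(x))
    sig = expit(x)
    return -label * np.log(sig) - (1 - label) * np.log(1 - sig)

def sigmoid_xent_grad(x, label, dout):
    # Analytic gradient w.r.t. x: dout * (sigmoid(x) - label)
    return dout * (expit(x) - label)

# Quick numeric check of the analytic gradient against central differences.
x = np.random.randn(4, 3)
label = np.random.uniform(0.1, 0.9, (4, 3))
eps = 1e-6
num = (sigmoid_xent_fw(x + eps, label) - sigmoid_xent_fw(x - eps, label)) / (2 * eps)
ana = sigmoid_xent_grad(x, label, np.ones_like(x))
assert np.allclose(num, ana, atol=1e-4)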
