
Commit 1d28a49

Merge pull request #16879 from phlrain/pick_sigmoid_cross
Merge pull request #16799 from phlrain/sigmoid_corss_entropy_support_…
2 parents: 7e33f5a + 19f3839

2 files changed: +127 -26 lines

paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc

Lines changed: 34 additions & 26 deletions
@@ -31,15 +31,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
 
     auto x_dims = ctx->GetInputDim("X");
     auto labels_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
-                      "The 2nd dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
+                      "Input(X) and Input(Label) shall have the same rank.");
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(labels_dims) <= 0)) {
+      check = false;
+    }
+
+    if (check) {
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
+                        framework::slice_ddim(labels_dims, 0, rank),
+                        "Input(X) and Input(Label) shall have the same shape "
+                        "except the last dimension.");
+    }
 
     ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
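
Note on the new check: `slice_ddim(x_dims, 0, rank)` spans all `rank` dimensions, so what is actually enforced is equality of the full shapes (the message's "except the last dimension" is looser than the comparison itself). The `!ctx->IsRuntime()` branch skips the comparison at graph-construction time whenever `product(dims) <= 0`, i.e. some dimension is still a `-1` placeholder. A minimal Python sketch of that guard, using a hypothetical helper name (not Paddle API):

from functools import reduce
from operator import mul

# Hypothetical helper mirroring the compile-time guard above (not Paddle API).
def shape_check_applicable(x_dims, label_dims, is_runtime):
    # product(dims) <= 0 exactly when some dim is 0 or a -1 placeholder,
    # so shape equality cannot be decided before runtime.
    product = lambda dims: reduce(mul, dims, 1)
    return is_runtime or (product(x_dims) > 0 and product(label_dims) > 0)

assert shape_check_applicable([10, 10, 20], [10, 10, 20], is_runtime=False)
assert not shape_check_applicable([-1, 10, 20], [-1, 10, 20], is_runtime=False)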
@@ -62,23 +69,24 @@ class SigmoidCrossEntropyWithLogitsGradOp
     auto x_dims = ctx->GetInputDim("X");
     auto labels_dims = ctx->GetInputDim("Label");
     auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dout_dims.size(), 2,
-                      "Input(Out@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
-                      "The 2nd dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0],
-                      "The 1st dimension of Input(X) and Input(Out@Grad) "
-                      "should be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1],
-                      "The 2nd dimension of Input(X) and Input(Out@Grad) "
-                      "should be equal.");
+
+    int rank = x_dims.size();
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(labels_dims) <= 0)) {
+      check = false;
+    }
+
+    if (check) {
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
+                        framework::slice_ddim(labels_dims, 0, rank),
+                        "Input(X) and Input(Label) shall have the same shape.");
+
+      PADDLE_ENFORCE_EQ(
+          framework::slice_ddim(x_dims, 0, rank),
+          framework::slice_ddim(dout_dims, 0, rank),
+          "Input(X) and Input(Out@Grad) shall have the same shape.");
+    }
 
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
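
Both the forward op and its gradient are purely elementwise, which is why the shape checks above are all the InferShape functions need once the rank-2 restriction is lifted. A minimal NumPy sketch of the math on a rank-3 input, mirroring the reference computation in the tests below (scipy's `expit`/`logit` are the sigmoid and its inverse):

import numpy as np
from scipy.special import expit, logit  # sigmoid and its inverse

# Rank-3 input: a [10, 10] "batch" of 20 classes, as in the new tests.
x = logit(np.random.uniform(0.1, 0.9, (10, 10, 20)).astype("float32"))
label = np.random.uniform(0, 1, (10, 10, 20)).astype("float32")

# Forward: Label * -log(sigmoid(X)) + (1 - Label) * -log(1 - sigmoid(X))
sigmoid_x = expit(x)
loss = -label * np.log(sigmoid_x) - (1 - label) * np.log(1 - sigmoid_x)

# Backward: d loss / d x = sigmoid(x) - label, elementwise, which is why
# the grad op sets dX to exactly X's dims via SetOutputDim above.
dx = sigmoid_x - label
assert loss.shape == dx.shape == x.shape == (10, 10, 20)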

python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py

Lines changed: 93 additions & 0 deletions
@@ -149,5 +149,98 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
+class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
+    """Test sigmoid_cross_entropy_with_logits op with probabilistic label
+    """
+
+    def setUp(self):
+        self.op_type = "sigmoid_cross_entropy_with_logits"
+        batch_size = [10, 10]
+        num_classes = 20
+        self.inputs = {
+            'X': logit(
+                np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                .astype("float32")),
+            'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+            .astype("float32")
+        }
+
+        # Fw Pass is implemented as elementwise sigmoid followed by
+        # elementwise logistic loss
+        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
+        sigmoid_X = expit(self.inputs['X'])
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+        self.outputs = {'Out': -term1 - term2}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestSigmoidCrossEntropyWithNorm2(OpTest):
+    def setUp(self):
+        self.op_type = "sigmoid_cross_entropy_with_logits"
+        batch_size = [10, 10]
+        num_classes = 20
+        ignore_index = -1
+        self.inputs = {
+            'X': logit(
+                np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                .astype("float32")),
+            'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes]))
+            .astype("float32")
+        }
+        self.attrs = {'ignore_index': ignore_index, 'normalize': True}
+        sigmoid_X = expit(self.inputs['X'])
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+        out = -term1 - term2
+        out[np.where(self.inputs['Label'] == ignore_index)] = 0
+        if self.attrs['normalize']:
+            out = out / float(
+                np.where(self.inputs['Label'] != ignore_index)[0].size)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
+    """Test sigmoid_cross_entropy_with_logits op with binary label
+    """
+
+    def setUp(self):
+        self.op_type = "sigmoid_cross_entropy_with_logits"
+        batch_size = [10, 10]
+        num_classes = 20
+        self.inputs = {
+            'X': logit(
+                np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                .astype("float32")),
+            'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes]))
+            .astype("float32")
+        }
+
+        # Fw Pass is implemented as elementwise sigmoid followed by
+        # elementwise logistic loss
+        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
+        sigmoid_X = expit(self.inputs['X'])
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+        self.outputs = {'Out': -term1 - term2}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 if __name__ == '__main__':
     unittest.main()
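
`TestSigmoidCrossEntropyWithNorm2` also exercises the `ignore_index` and `normalize` attributes on the new higher-rank inputs: entries whose label equals `ignore_index` contribute zero loss, and `normalize=True` divides the total by the number of non-ignored entries. A standalone NumPy sketch of that reference computation (same recipe as the test's `setUp`):

import numpy as np
from scipy.special import expit, logit

ignore_index = -1
x = logit(np.random.uniform(0.1, 0.9, (10, 10, 20)).astype("float32"))
label = np.random.randint(-1, 2, (10, 10, 20)).astype("float32")

out = -label * np.log(expit(x)) - (1 - label) * np.log(1 - expit(x))
out[label == ignore_index] = 0               # ignored entries carry no loss
out /= float((label != ignore_index).sum())  # normalize=True: divide by kept count
assert out.shape == (10, 10, 20)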
