Commit d094cd0

change shape of output in cross_entropy (#29414)
1 parent 401cc1e commit d094cd0

File tree

2 files changed: +45 additions, -0 deletions

python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py

Lines changed: 41 additions & 0 deletions
@@ -219,6 +219,47 @@ def test_cross_entropy_loss_1d_with_weight_none(self):
         self.assertTrue(np.allclose(static_ret, expected))
         self.assertTrue(np.allclose(dy_ret_value, expected))
 
+    def test_cross_entropy_loss_1d_with_weight_none_func(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)  #N,C
+        label_np = np.random.randint(0, 100, size=(100)).astype(np.int64)  #N
+        weight_np = np.random.random([200]).astype(np.float64)  #C
+        paddle.enable_static()
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            weight = fluid.data(name='weight', shape=[200], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='none')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            static_ret = np.squeeze(static_ret)
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='none')
+            dy_ret_value = dy_ret.numpy()
+            dy_ret_value = np.squeeze(dy_ret_value)
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np, reduction='none')
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
     def test_cross_entropy_loss_1d_mean(self):
         input_np = np.random.random([100, 200]).astype(np.float64)  #N,C
         label_np = np.random.randint(0, 100, size=(100)).astype(np.int64)  #N,1
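
The new test compares the unreduced, weighted loss from both static-graph and dygraph execution against the NumPy helper cross_entropy_loss_1d defined earlier in the test file. As a rough sketch only (this is not the file's actual helper; the name reference_cross_entropy_1d and its exact layout are illustrative), such a per-sample reference can be written as:

import numpy as np

def reference_cross_entropy_1d(input_np, label_np, weight_np=None):
    # Log-softmax over the class axis, then the negative log-likelihood of
    # each sample's true class, scaled by that class's weight if given.
    shifted = input_np - input_np.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_softmax[np.arange(label_np.shape[0]), label_np]
    if weight_np is not None:
        nll = nll * weight_np[label_np]
    return nll  # shape [N], matching reduction='none' after this commit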

python/paddle/nn/functional/loss.py

Lines changed: 4 additions & 0 deletions
@@ -1236,6 +1236,8 @@ def cross_entropy(input,
             else:
                 return core.ops.mean(out)
         else:
+            if input_dims - 1 == label_dims:
+                out = paddle.squeeze(out, axis=axis)
             return out
 
     fluid.data_feeder.check_variable_and_dtype(
@@ -1267,6 +1269,8 @@ def cross_entropy(input,
        else:
            return paddle.mean(out, name=name)
    else:
+        if input_dims - 1 == label_dims:
+            out = paddle.squeeze(out, axis=axis)
        return out
 
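
Taken together, the two hunks squeeze the class axis out of the unreduced loss whenever the label tensor has one fewer dimension than the input, in both the dygraph and the static-graph code paths. A minimal dygraph sketch of the resulting behaviour (not part of the commit; it assumes a Paddle 2.x build where paddle.to_tensor is available):

import numpy as np
import paddle

# N=4 samples, C=10 classes; the label has one fewer dimension than the input.
input_np = np.random.random([4, 10]).astype('float64')
label_np = np.random.randint(0, 10, size=(4, )).astype('int64')

loss = paddle.nn.functional.cross_entropy(
    paddle.to_tensor(input_np),
    paddle.to_tensor(label_np),
    reduction='none')

# With this change the unreduced loss is squeezed along the class axis,
# so its shape matches the label: [4] rather than [4, 1].
print(loss.shape)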
