
Commit 8036231

aztice and SigureMo authored
[API] Support zero_infinity in ctc_loss (#75742)
---------

Co-authored-by: Nyakku Shigure <[email protected]>
1 parent 3b5e90a commit 8036231

File tree

3 files changed: +102 -1 lines changed

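In short, the commit threads a new zero_infinity flag through the functional API, the CTCLoss layer, and the tests. A minimal dygraph sketch of the intended usage, mirroring the shapes used in the new tests (the concrete values here are illustrative, not part of the commit):

import numpy as np
import paddle

# [max_time, batch, n_class] = [1, 1, 8]: the label sequence (length 3)
# is longer than the input (length 1), so no CTC alignment exists and
# the raw per-sample loss is +inf; zero_infinity=True masks it to 0.
logits = paddle.to_tensor(np.random.randn(1, 1, 8).astype("float32"))
labels = paddle.to_tensor(np.random.randint(0, 7, (1, 3)).astype("int32"))
input_lengths = paddle.to_tensor(np.array([1], dtype=np.int64))
label_lengths = paddle.to_tensor(np.array([3], dtype=np.int64))

loss = paddle.nn.functional.ctc_loss(
    log_probs=logits,
    labels=labels,
    input_lengths=input_lengths,
    label_lengths=label_lengths,
    blank=7,
    reduction="none",
    zero_infinity=True,
)
print(loss.numpy())  # expected: [0.]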

python/paddle/nn/functional/loss.py

Lines changed: 12 additions & 1 deletion
@@ -1911,6 +1911,7 @@ def ctc_loss(
     blank: int = 0,
     reduction: _ReduceMode = 'mean',
     norm_by_times: bool = False,
+    zero_infinity: bool = False,
 ) -> Tensor:
     """
@@ -1927,6 +1928,7 @@ def ctc_loss(
         blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default: 0.
         reduction (str, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default: ``'mean'``.
         norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default: False.
+        zero_infinity (bool, optional): If True, set infinite loss to zero. Default: False.

     Returns:
         Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is []. Data type is the same as ``log_probs``.
@@ -2041,8 +2043,17 @@ def warpctc(
     loss_out = warpctc(
         log_probs, labels, blank, norm_by_times, input_lengths, label_lengths
     )
-
     loss_out = paddle.squeeze(loss_out, [-1])
+
+    if zero_infinity:
+        inf_mask = paddle.isinf(loss_out)
+        zero_value = paddle.zeros_like(loss_out)
+        loss_out = paddle.where(
+            condition=inf_mask,
+            x=zero_value,
+            y=loss_out,
+        )
+
     assert reduction in ['mean', 'sum', 'none']
     if reduction == 'mean':
         loss_out = paddle.mean(loss_out / label_lengths.astype(loss_out.dtype))
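The masking added above is a standard isinf/where pattern: any +inf entry in the per-sample loss is replaced by 0 before the reduction is applied. A standalone sketch of the same pattern with made-up values (not part of the diff):

import paddle

loss_out = paddle.to_tensor([2.5, float("inf"), 0.75])
inf_mask = paddle.isinf(loss_out)          # [False, True, False]
loss_out = paddle.where(inf_mask, paddle.zeros_like(loss_out), loss_out)
print(loss_out.numpy())                    # [2.5  0.   0.75]

Because the masked entries are taken from the constant zero branch of paddle.where, this should also zero the gradients flowing back through those samples.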

python/paddle/nn/layer/loss.py

Lines changed: 2 additions & 0 deletions
@@ -1352,6 +1352,7 @@ def forward(
         input_lengths: Tensor,
         label_lengths: Tensor,
         norm_by_times: bool = False,
+        zero_infinity: bool = False,
     ) -> Tensor:
         return paddle.nn.functional.ctc_loss(
             log_probs,
@@ -1361,6 +1362,7 @@ def forward(
             self.blank,
             self.reduction,
             norm_by_times=norm_by_times,
+            zero_infinity=zero_infinity,
         )
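On the layer side, the new argument is simply forwarded from CTCLoss.forward to the functional call. A minimal usage sketch under the same infeasible-sample setup as the tests (values illustrative only):

import numpy as np
import paddle

ctc = paddle.nn.CTCLoss(blank=7, reduction="none")
loss = ctc(
    paddle.to_tensor(np.random.randn(1, 1, 8).astype("float32")),       # log_probs [T, N, C]
    paddle.to_tensor(np.random.randint(0, 7, (1, 3)).astype("int32")),  # labels
    paddle.to_tensor(np.array([1], dtype=np.int64)),                    # input_lengths
    paddle.to_tensor(np.array([3], dtype=np.int64)),                    # label_lengths
    zero_infinity=True,
)
print(loss.numpy())  # expected: [0.]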

test/legacy_test/test_warpctc_op.py

Lines changed: 88 additions & 0 deletions
@@ -754,6 +754,94 @@ def test_functional_api():

         test_functional_api()

+    def test_ctc_loss_zero_infinity(self):
+        max_time = 1
+        batch = 1
+        n_class = 8
+        logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
+        labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
+            "int32"
+        )
+        input_len_np = np.array([1], dtype=np.int64)
+        label_len_np = np.array([3], dtype=np.int64)
+
+        paddle.enable_static()
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+
+        with paddle.static.program_guard(main_program, startup_program):
+            logits = paddle.static.data(
+                name="logits_il",
+                shape=[max_time, batch, n_class],
+                dtype="float32",
+            )
+            labels = paddle.static.data(
+                name="labels_il", shape=[batch, 3], dtype="int32"
+            )
+            input_len = paddle.static.data(
+                name="input_len_il", shape=[batch], dtype="int64"
+            )
+            label_len = paddle.static.data(
+                name="label_len_il", shape=[batch], dtype="int64"
+            )
+
+            loss = paddle.nn.functional.ctc_loss(
+                log_probs=logits,
+                labels=labels,
+                input_lengths=input_len,
+                label_lengths=label_len,
+                reduction="none",
+                zero_infinity=True,
+                blank=n_class - 1,
+            )
+
+        exe = paddle.static.Executor()
+        loss_val = exe.run(
+            main_program,
+            feed={
+                "logits_il": logits_np,
+                "labels_il": labels_np,
+                "input_len_il": input_len_np,
+                "label_len_il": label_len_np,
+            },
+            fetch_list=[loss],
+        )[0]
+
+        # illegal sample -> 0
+        np.testing.assert_allclose(loss_val, [0.0], atol=1e-6)
+
+        paddle.disable_static()
+
+    def test_ctc_loss_zero_infinity_dygraph(self):
+        max_time = 1
+        batch = 1
+        n_class = 8
+
+        logits_np = np.random.randn(max_time, batch, n_class).astype("float32")
+        labels_np = np.random.randint(0, n_class - 1, (batch, 3)).astype(
+            "int32"
+        )
+        input_len_np = np.array([1], dtype=np.int64)
+        label_len_np = np.array([3], dtype=np.int64)
+
+        paddle.disable_static()
+        logits = paddle.to_tensor(logits_np)
+        labels = paddle.to_tensor(labels_np)
+        input_len = paddle.to_tensor(input_len_np)
+        label_len = paddle.to_tensor(label_len_np)
+
+        loss = paddle.nn.functional.ctc_loss(
+            log_probs=logits,
+            labels=labels,
+            input_lengths=input_len,
+            label_lengths=label_len,
+            reduction="none",
+            zero_infinity=True,
+            blank=n_class - 1,
+        )
+
+        np.testing.assert_allclose(loss.numpy(), [0.0], rtol=1e-6)
+

 if __name__ == "__main__":
     unittest.main()
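Both tests rely on the "illegal sample -> 0" construction: with input_lengths=1 and label_lengths=3 there is no valid CTC alignment, so the raw loss is infinite and the flag should collapse it to zero. In general, a CTC alignment needs at least one time step per label plus one extra step for every adjacent repeated label (a blank must separate repeats); a hypothetical helper illustrating that feasibility check, not part of the commit:

import numpy as np

def ctc_feasible(input_len, labels_row):
    # Minimum alignment length = number of labels plus one mandatory
    # blank between each pair of adjacent repeated labels.
    repeats = int(np.sum(labels_row[1:] == labels_row[:-1]))
    return input_len >= len(labels_row) + repeats

print(ctc_feasible(1, np.array([2, 5, 1])))  # False -> raw loss would be +inf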
