Skip to content

Commit 8c7e02b

Browse files
authored
[API] Move the zero_infinity parameter of CTCLoss from the forward function to the constructor (#76156)
* Enhance CTC loss with zero_infinity parameter Added 'zero_infinity' parameter to CTC loss initialization and forward method. * Refactor constructor in loss.py for clarity * Change zero_infinity type to optional boolean * Add Optional import to loss.py * Change default value of zero_infinity to False * Remove unnecessary import of Optional from typing * Update zero_infinity handling in ctc_loss function Refactor zero_infinity parameter handling in CTC loss.
1 parent 09eb01f commit 8c7e02b

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

python/paddle/nn/layer/loss.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,10 +1340,16 @@ class CTCLoss(Layer):
13401340
blank: int
13411341
reduction: _ReduceMode
13421342

1343-
def __init__(self, blank: int = 0, reduction: _ReduceMode = 'mean') -> None:
1343+
def __init__(
1344+
self,
1345+
blank: int = 0,
1346+
reduction: _ReduceMode = 'mean',
1347+
zero_infinity: bool = False,
1348+
) -> None:
13441349
super().__init__()
13451350
self.blank = blank
13461351
self.reduction = reduction
1352+
self.zero_infinity = zero_infinity
13471353

13481354
def forward(
13491355
self,
@@ -1352,7 +1358,6 @@ def forward(
13521358
input_lengths: Tensor,
13531359
label_lengths: Tensor,
13541360
norm_by_times: bool = False,
1355-
zero_infinity: bool = False,
13561361
) -> Tensor:
13571362
return paddle.nn.functional.ctc_loss(
13581363
log_probs,
@@ -1362,7 +1367,7 @@ def forward(
13621367
self.blank,
13631368
self.reduction,
13641369
norm_by_times=norm_by_times,
1365-
zero_infinity=zero_infinity,
1370+
zero_infinity=self.zero_infinity,
13661371
)
13671372

13681373

0 commit comments

Comments
 (0)