1313# limitations under the License.
1414
1515from typing import Tuple
16- import numpy as np
1716import tensorflow as tf
1817from nltk .metrics import distance
1918from .utils import bytes_to_string
2019
2120
22- def wer (decode : np .ndarray , target : np .ndarray ) -> Tuple [tf .Tensor , tf .Tensor ]:
23- """Word Error Rate
24-
25- Args:
26- decode (np.ndarray): array of prediction texts
27- target (np.ndarray): array of groundtruth texts
28-
29- Returns:
30- tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
31- """
21+ def _wer (decode , target ):
3222 decode = bytes_to_string (decode )
3323 target = bytes_to_string (target )
3424 dis = 0.0
@@ -45,16 +35,20 @@ def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
4535 return tf .convert_to_tensor (dis , tf .float32 ), tf .convert_to_tensor (length , tf .float32 )
4636
4737
48- def cer ( decode : np . ndarray , target : np . ndarray ) -> Tuple [tf .Tensor , tf .Tensor ]:
49- """Character Error Rate
38+ def wer ( _decode : tf . Tensor , _target : tf . Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
39+ """Word Error Rate
5040
5141 Args:
5242 decode (np.ndarray): array of prediction texts
5343 target (np.ndarray): array of groundtruth texts
5444
5545 Returns:
56- tuple: a tuple of tf.Tensor of (edit distances, number of characters ) of each text
46+ tuple: a tuple of tf.Tensor of (edit distances, number of words ) of each text
5747 """
48+ return tf .numpy_function (_wer , inp = [_decode , _target ], Tout = [tf .float32 , tf .float32 ])
49+
50+
51+ def _cer (decode , target ):
5852 decode = bytes_to_string (decode )
5953 target = bytes_to_string (target )
6054 dis = 0
@@ -65,6 +59,36 @@ def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
6559 return tf .convert_to_tensor (dis , tf .float32 ), tf .convert_to_tensor (length , tf .float32 )
6660
6761
62+ def cer (_decode : tf .Tensor , _target : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
63+ """Character Error Rate
64+
65+ Args:
66+ decode (np.ndarray): array of prediction texts
67+ target (np.ndarray): array of groundtruth texts
68+
69+ Returns:
70+ tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
71+ """
72+ return tf .numpy_function (_cer , inp = [_decode , _target ], Tout = [tf .float32 , tf .float32 ])
73+
74+
75+ def tf_cer (decode : tf .Tensor , target : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
76+ """Tensorflwo Charactor Error rate
77+
78+ Args:
79+ decoder (tf.Tensor): tensor shape [B]
80+ target (tf.Tensor): tensor shape [B]
81+
82+ Returns:
83+ tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
84+ """
85+ decode = tf .strings .bytes_split (decode ) # [B, N]
86+ target = tf .strings .bytes_split (target ) # [B, M]
87+ distances = tf .edit_distance (decode .to_sparse (), target .to_sparse (), normalize = False ) # [B]
88+ lengths = tf .cast (target .row_lengths (axis = 1 ), dtype = tf .float32 ) # [B]
89+ return tf .reduce_sum (distances ), tf .reduce_sum (lengths )
90+
91+
6892class ErrorRate (tf .keras .metrics .Metric ):
6993 """ Metric for WER and CER """
7094
@@ -75,10 +99,9 @@ def __init__(self, func, name="error_rate", **kwargs):
7599 self .func = func
76100
77101 def update_state (self , decode : tf .Tensor , target : tf .Tensor ):
78- n , d = tf . numpy_function ( self .func , inp = [ decode , target ], Tout = [ tf . float32 , tf . float32 ] )
102+ n , d = self .func ( decode , target )
79103 self .numerator .assign_add (n )
80104 self .denominator .assign_add (d )
81105
82106 def result (self ):
83- if self .denominator == 0.0 : return 0.0
84- return (self .numerator / self .denominator ) * 100
107+ return tf .math .divide_no_nan (self .numerator , self .denominator ) * 100
0 commit comments