1818from .utils import bytes_to_string
1919
2020
def _wer(decode, target):
    """Accumulate word-level edit distance and reference word count over a batch.

    Every distinct word of a (hypothesis, reference) pair is mapped to its own
    placeholder character, so the character-level ``distance.edit_distance``
    call ends up counting word edits.

    Args:
        decode: batch of predicted texts, passed through ``bytes_to_string``
            (assumed to yield an iterable of ``str`` -- confirm against utils).
        target: batch of reference texts, converted the same way.

    Returns:
        tuple: two scalar tf.float32 tensors, (summed edit distance,
            summed number of words in the references).
    """
    hypotheses = bytes_to_string(decode)
    references = bytes_to_string(target)
    total_distance = 0.0
    total_words = 0.0
    for hypothesis, reference in zip(hypotheses, references):
        hyp_words = hypothesis.split()
        ref_words = reference.split()
        # One unique character per distinct word of this pair; which character
        # a word receives does not influence the resulting edit distance.
        symbol_of = {word: chr(index) for index, word in enumerate(set(hyp_words + ref_words))}
        hyp_symbols = "".join(symbol_of[word] for word in hyp_words)
        ref_symbols = "".join(symbol_of[word] for word in ref_words)
        total_distance += distance.edit_distance(hyp_symbols, ref_symbols)
        total_words += len(ref_words)
    return tf.convert_to_tensor(total_distance, tf.float32), tf.convert_to_tensor(total_words, tf.float32)
36+
37+
def wer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Word Error Rate

    Runs the Python helper ``_wer`` through ``tf.numpy_function``.

    Args:
        _decode (tf.Tensor): predicted texts, forwarded as the first input of ``_wer``
            (presumably a batch of byte strings -- confirm against callers)
        _target (tf.Tensor): reference texts, forwarded as the second input of ``_wer``

    Returns:
        tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
    """
    outputs = tf.numpy_function(_wer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
    return outputs
def _cer(decode, target):
    """Accumulate character-level edit distance and reference length over a batch.

    Args:
        decode: batch of predicted texts, passed through ``bytes_to_string``
            (assumed to yield an iterable of ``str`` -- confirm against utils).
        target: batch of reference texts, converted the same way.

    Returns:
        tuple: two scalar tf.float32 tensors, (summed edit distance,
            summed number of characters in the references).
    """
    hypotheses = bytes_to_string(decode)
    references = bytes_to_string(target)
    # Materialise the pairs once because they are traversed twice below.
    pairs = list(zip(hypotheses, references))
    total_distance = sum(distance.edit_distance(hypothesis, reference) for hypothesis, reference in pairs)
    total_chars = sum(len(reference) for _, reference in pairs)
    return tf.convert_to_tensor(total_distance, tf.float32), tf.convert_to_tensor(total_chars, tf.float32)
4860
4961
def cer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Character Error Rate

    Runs the Python helper ``_cer`` through ``tf.numpy_function``.

    Args:
        _decode (tf.Tensor): predicted texts, forwarded as the first input of ``_cer``
            (presumably a batch of byte strings -- confirm against callers)
        _target (tf.Tensor): reference texts, forwarded as the second input of ``_cer``

    Returns:
        tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
    """
    outputs = tf.numpy_function(_cer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
    return outputs
7173
7274
7375def tf_cer (decode : tf .Tensor , target : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
0 commit comments