Skip to content

Commit 839e5fa

Browse files
committed
✍️ update metrics
1 parent 7c8d239 commit 839e5fa

File tree

1 file changed

+28
-26
lines changed

1 file changed

+28
-26
lines changed

tensorflow_asr/utils/metrics.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,23 @@
1818
from .utils import bytes_to_string
1919

2020

21+
def _wer(decode, target):
22+
decode = bytes_to_string(decode)
23+
target = bytes_to_string(target)
24+
dis = 0.0
25+
length = 0.0
26+
for dec, tar in zip(decode, target):
27+
words = set(dec.split() + tar.split())
28+
word2char = dict(zip(words, range(len(words))))
29+
30+
new_decode = [chr(word2char[w]) for w in dec.split()]
31+
new_target = [chr(word2char[w]) for w in tar.split()]
32+
33+
dis += distance.edit_distance(''.join(new_decode), ''.join(new_target))
34+
length += len(tar.split())
35+
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
36+
37+
2138
def wer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
2239
"""Word Error Rate
2340
@@ -28,23 +45,18 @@ def wer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
2845
Returns:
2946
tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
3047
"""
31-
def fn(decode, target):
32-
decode = bytes_to_string(decode)
33-
target = bytes_to_string(target)
34-
dis = 0.0
35-
length = 0.0
36-
for dec, tar in zip(decode, target):
37-
words = set(dec.split() + tar.split())
38-
word2char = dict(zip(words, range(len(words))))
39-
40-
new_decode = [chr(word2char[w]) for w in dec.split()]
41-
new_target = [chr(word2char[w]) for w in tar.split()]
48+
return tf.numpy_function(_wer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
4249

43-
dis += distance.edit_distance(''.join(new_decode), ''.join(new_target))
44-
length += len(tar.split())
45-
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
4650

47-
return tf.numpy_function(fn, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
51+
def _cer(decode, target):
52+
decode = bytes_to_string(decode)
53+
target = bytes_to_string(target)
54+
dis = 0
55+
length = 0
56+
for dec, tar in zip(decode, target):
57+
dis += distance.edit_distance(dec, tar)
58+
length += len(tar)
59+
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
4860

4961

5062
def cer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -57,17 +69,7 @@ def cer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
5769
Returns:
5870
tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
5971
"""
60-
def fn(decode, target):
61-
decode = bytes_to_string(decode)
62-
target = bytes_to_string(target)
63-
dis = 0
64-
length = 0
65-
for dec, tar in zip(decode, target):
66-
dis += distance.edit_distance(dec, tar)
67-
length += len(tar)
68-
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
69-
70-
return tf.numpy_function(fn, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
72+
return tf.numpy_function(_cer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
7173

7274

7375
def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:

0 commit comments

Comments
 (0)