import unittest

import numpy as np
import tensorflow as tf

from mltu.metrics import CERMetric, WERMetric


class TestMetrics(unittest.TestCase):
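    """Unit tests for CERMetric.get_cer and WERMetric.get_wer from mltu.metrics."""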

    def to_embeddings(self, sentences, vocab):
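        """Map every sentence to a list of character indices in vocab; return the embeddings and the longest sentence length."""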
        embeddings, max_len = [], 0

21- # def to_embeddings(sentences, vocab):
22- # embeddings, max_len = [], 0
13+ for sentence in sentences :
14+ embedding = []
15+ for character in sentence :
16+ embedding .append (vocab .index (character ))
17+ embeddings .append (embedding )
18+ max_len = max (max_len , len (embedding ))
19+ return embeddings , max_len
2320
    def setUp(self) -> None:
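        """Embed the reference and predicted sentences and pad them for the metric helpers: ground-truth sequences are padded with len(vocab), predictions with -1."""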
        true_words = ['Who are you', 'I am a student', 'I am a teacher', 'Just different sentence length']
        pred_words = ['Who are you', 'I am a ztudent', 'I am A reacher', 'Just different length']

        vocab = set()
        for sen in true_words + pred_words:
            for character in sen:
                vocab.add(character)
        self.vocab = "".join(vocab)

        sentence_true, max_len_true = self.to_embeddings(true_words, self.vocab)
        sentence_pred, max_len_pred = self.to_embeddings(pred_words, self.vocab)

        max_len = max(max_len_true, max_len_pred)
        padding_length = 64

        self.sen_true = [np.pad(sen, (0, max_len - len(sen)), 'constant', constant_values=len(self.vocab)) for sen in sentence_true]
        self.sen_pred = [np.pad(sen, (0, padding_length - len(sen)), 'constant', constant_values=-1) for sen in sentence_pred]

    def test_CERMetric(self):
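        """get_cer should return the per-sentence character error rate.

        For example, 'I am a ztudent' differs from 'I am a student' in one of
        14 characters, so its CER is 1/14, about 0.0714.
        """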
        vocabulary = tf.constant(list(self.vocab))
        cer = CERMetric.get_cer(self.sen_true, self.sen_pred, vocabulary).numpy()

        self.assertTrue(np.array_equal(cer, np.array([0.0, 0.071428575, 0.14285715, 0.42857143], dtype=np.float32)))

    def test_WERMetric(self):
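        """get_wer should return the per-sentence word error rate.

        For example, 'I am a ztudent' has one incorrect word out of four,
        so its WER is 0.25.
        """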
        vocabulary = tf.constant(list(self.vocab))
        wer = WERMetric.get_wer(self.sen_true, self.sen_pred, vocabulary).numpy()

        self.assertTrue(np.array_equal(wer, np.array([0., 0.25, 0.5, 0.33333334], dtype=np.float32)))


if __name__ == "__main__":
    unittest.main()