1+ import unittest
2+
3+ from mltu .utils .text_utils import edit_distance , get_cer , get_wer
4+
5+ class TestTextUtils (unittest .TestCase ):
6+
7+ def test_edit_distance (self ):
8+ """ This unit test includes several test cases to cover different scenarios, including no errors,
9+ substitution errors, insertion errors, deletion errors, and a more complex case with multiple
10+ errors. It also includes a test case for empty input.
11+ """
12+ # Test simple case with no errors
13+ prediction_tokens = ['A' , 'B' , 'C' ]
14+ reference_tokens = ['A' , 'B' , 'C' ]
15+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 0 )
16+
17+ # Test simple case with one substitution error
18+ prediction_tokens = ['A' , 'B' , 'D' ]
19+ reference_tokens = ['A' , 'B' , 'C' ]
20+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
21+
22+ # Test simple case with one insertion error
23+ prediction_tokens = ['A' , 'B' , 'C' ]
24+ reference_tokens = ['A' , 'B' , 'C' , 'D' ]
25+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
26+
27+ # Test simple case with one deletion error
28+ prediction_tokens = ['A' , 'B' ]
29+ reference_tokens = ['A' , 'B' , 'C' ]
30+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
31+
32+ # Test more complex case with multiple errors
33+ prediction_tokens = ['A' , 'B' , 'C' , 'D' , 'E' ]
34+ reference_tokens = ['A' , 'C' , 'B' , 'F' , 'E' ]
35+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 3 )
36+
37+ # Test empty input
38+ prediction_tokens = []
39+ reference_tokens = []
40+ self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 0 )
41+
42+ def test_get_cer (self ):
43+ # Test simple case with no errors
44+ preds = ['A B C' ]
45+ target = ['A B C' ]
46+ self .assertEqual (get_cer (preds , target ), 0 )
47+
48+ # Test simple case with one character error
49+ preds = ['A B C' ]
50+ target = ['A B D' ]
51+ self .assertEqual (get_cer (preds , target ), 1 / 5 )
52+
53+ # Test simple case with multiple character errors
54+ preds = ['A B C' ]
55+ target = ['D E F' ]
56+ self .assertEqual (get_cer (preds , target ), 3 / 5 )
57+
58+ # Test empty input
59+ preds = []
60+ target = []
61+ self .assertEqual (get_cer (preds , target ), 0 )
62+
63+ # Test simple case with different word lengths
64+ preds = ['ABC' ]
65+ target = ['ABCDEFG' ]
66+ self .assertEqual (get_cer (preds , target ), 4 / 7 )
67+
68+ def test_get_wer (self ):
69+ # Test simple case with no errors
70+ preds = 'A B C'
71+ target = 'A B C'
72+ self .assertEqual (get_wer (preds , target ), 0 )
73+
74+ # Test simple case with one word error
75+ preds = 'A B C'
76+ target = 'A B D'
77+ self .assertEqual (get_wer (preds , target ), 1 / 3 )
78+
79+ # Test simple case with multiple word errors
80+ preds = 'A B C'
81+ target = 'D E F'
82+ self .assertEqual (get_wer (preds , target ), 1 )
83+
84+ # Test empty input
85+ preds = ""
86+ target = ""
87+ self .assertEqual (get_wer (preds , target ), 0 )
88+
89+ # Test simple case with different sentence lengths
90+ preds = ['ABC' ]
91+ target = ['ABC DEF' ]
92+ self .assertEqual (get_wer (preds , target ), 1 )
93+
94+ if __name__ == '__main__' :
95+ unittest .main ()
0 commit comments