@@ -10,28 +10,28 @@ def test_edit_distance(self):
1010 errors. It also includes a test case for empty input.
1111 """
1212 # Test simple case with no errors
13- prediction_tokens = ['A' , 'B' , 'C' ]
14- reference_tokens = ['A' , 'B' , 'C' ]
13+ prediction_tokens = ["A" , "B" , "C" ]
14+ reference_tokens = ["A" , "B" , "C" ]
1515 self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 0 )
1616
1717 # Test simple case with one substitution error
18- prediction_tokens = ['A' , 'B' , 'D' ]
19- reference_tokens = ['A' , 'B' , 'C' ]
18+ prediction_tokens = ["A" , "B" , "D" ]
19+ reference_tokens = ["A" , "B" , "C" ]
2020 self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
2121
2222 # Test simple case with one insertion error
23- prediction_tokens = ['A' , 'B' , 'C' ]
24- reference_tokens = ['A' , 'B' , 'C' , 'D' ]
23+ prediction_tokens = ["A" , "B" , "C" ]
24+ reference_tokens = ["A" , "B" , "C" , "D" ]
2525 self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
2626
2727 # Test simple case with one deletion error
28- prediction_tokens = ['A' , 'B' ]
29- reference_tokens = ['A' , 'B' , 'C' ]
28+ prediction_tokens = ["A" , "B" ]
29+ reference_tokens = ["A" , "B" , "C" ]
3030 self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 1 )
3131
3232 # Test more complex case with multiple errors
33- prediction_tokens = ['A' , 'B' , 'C' , 'D' , 'E' ]
34- reference_tokens = ['A' , 'C' , 'B' , 'F' , 'E' ]
33+ prediction_tokens = ["A" , "B" , "C" , "D" , "E" ]
34+ reference_tokens = ["A" , "C" , "B" , "F" , "E" ]
3535 self .assertEqual (edit_distance (prediction_tokens , reference_tokens ), 3 )
3636
3737 # Test empty input
@@ -41,18 +41,18 @@ def test_edit_distance(self):
4141
4242 def test_get_cer (self ):
4343 # Test simple case with no errors
44- preds = [' A B C' ]
45- target = [' A B C' ]
44+ preds = [" A B C" ]
45+ target = [" A B C" ]
4646 self .assertEqual (get_cer (preds , target ), 0 )
4747
4848 # Test simple case with one character error
49- preds = [' A B C' ]
50- target = [' A B D' ]
49+ preds = [" A B C" ]
50+ target = [" A B D" ]
5151 self .assertEqual (get_cer (preds , target ), 1 / 5 )
5252
5353 # Test simple case with multiple character errors
54- preds = [' A B C' ]
55- target = [' D E F' ]
54+ preds = [" A B C" ]
55+ target = [" D E F" ]
5656 self .assertEqual (get_cer (preds , target ), 3 / 5 )
5757
5858 # Test empty input
@@ -61,24 +61,24 @@ def test_get_cer(self):
6161 self .assertEqual (get_cer (preds , target ), 0 )
6262
6363 # Test simple case with different word lengths
64- preds = [' ABC' ]
65- target = [' ABCDEFG' ]
64+ preds = [" ABC" ]
65+ target = [" ABCDEFG" ]
6666 self .assertEqual (get_cer (preds , target ), 4 / 7 )
6767
6868 def test_get_wer (self ):
6969 # Test simple case with no errors
70- preds = ' A B C'
71- target = ' A B C'
70+ preds = " A B C"
71+ target = " A B C"
7272 self .assertEqual (get_wer (preds , target ), 0 )
7373
7474 # Test simple case with one word error
75- preds = ' A B C'
76- target = ' A B D'
75+ preds = " A B C"
76+ target = " A B D"
7777 self .assertEqual (get_wer (preds , target ), 1 / 3 )
7878
7979 # Test simple case with multiple word errors
80- preds = ' A B C'
81- target = ' D E F'
80+ preds = " A B C"
81+ target = " D E F"
8282 self .assertEqual (get_wer (preds , target ), 1 )
8383
8484 # Test empty input
@@ -87,9 +87,10 @@ def test_get_wer(self):
8787 self .assertEqual (get_wer (preds , target ), 0 )
8888
8989 # Test simple case with different sentence lengths
90- preds = [' ABC' ]
91- target = [' ABC DEF' ]
90+ preds = [" ABC" ]
91+ target = [" ABC DEF" ]
9292 self .assertEqual (get_wer (preds , target ), 1 )
9393
94- if __name__ == '__main__' :
95- unittest .main ()
94+
95+ if __name__ == "__main__" :
96+ unittest .main ()
0 commit comments