Commit 57bf2da

working on CER and WER TensorFlow metrics as well, writing unittests for them
1 parent 940912e commit 57bf2da

9 files changed (+626, -6 lines)

.vscode/settings.json

Lines changed: 10 additions & 1 deletion
@@ -1,3 +1,12 @@
 {
-    "python.analysis.typeCheckingMode": "off"
+    "python.analysis.typeCheckingMode": "off",
+    "python.testing.unittestArgs": [
+        "-v",
+        "-s",
+        "./Tests",
+        "-p",
+        "*test*.py"
+    ],
+    "python.testing.pytestEnabled": false,
+    "python.testing.unittestEnabled": true
 }
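
These settings point VS Code's Test Explorer at the new Tests folder. For reference, they are roughly equivalent to running unittest discovery from the repository root:

    python -m unittest discover -v -s ./Tests -p "*test*.py"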

Tests/test_metrics.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import numpy as np
from mltu.metrics import CERMetric, WERMetric

from mltu.utils.text_utils import get_wer as wer

import cv2
import typing
import tensorflow as tf

if __name__ == "__main__":
    import pandas as pd
    from tqdm import tqdm

    # sentences_true = ['helo love', 'helo home', 'helo world']
    # sentences_pred = ['helo python', 'helo home', 'helo python here']

    # def to_embeddings(sentences, vocab):
    #     embeddings, max_len = [], 0
    #     for sentence in sentences:
    #         embedding = []
    #         for character in sentence:
    #             embedding.append(vocab.index(character))
    #         embeddings.append(embedding)
    #         max_len = max(max_len, len(embedding))
    #     return embeddings, max_len

    # vocab = set()
    # for sen in sentences_true + sentences_pred:
    #     for character in sen:
    #         vocab.add(character)
    # vocab = "".join(vocab)

    # sen1, max_len = to_embeddings(sentences_true, vocab)
    # sen2, _ = to_embeddings(sentences_pred, vocab)

    # sen_true = [np.pad(sen, (0, max_len - len(sen)), 'constant', constant_values=len(vocab)) for sen in sen1]
    # sen_pred = [np.pad(sen, (0, 24 - len(sen)), 'constant', constant_values=-1) for sen in sen2]

    # tf_vocab = tf.constant(list(vocab))

    # distance = WERMetric.get_wer(sen_pred, sen_true, vocab=tf_vocab)

    # d = wer(sentences_pred, sentences_true)

    # print(list(distance.numpy()))
    # print(d)

    word_true = [
        [1, 2, 3, 4, 5, 6, 1],
        [2, 3, 4, 5, 6, 1, 1]
    ]
    word_pred = [
        [1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [2, 3, 4, 5, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ]
    vocabulary = tf.constant(list("abcdefg"))

    distance = CERMetric.get_cer(word_pred, word_true, vocabulary)
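
The get_cer call above should return one normalized edit distance per sample. As a point of reference, here is a minimal sketch of that computation using tf.edit_distance; it assumes -1 padding as in word_pred above, ignores the vocabulary argument, and is not the actual CERMetric implementation:

import tensorflow as tf

def cer_sketch(word_pred, word_true, padding_token=-1):
    # Convert padded dense label arrays into the SparseTensor form that
    # tf.edit_distance expects, dropping the padding token so it does
    # not count as a character.
    pred = tf.cast(tf.constant(word_pred), tf.int64)
    true = tf.cast(tf.constant(word_true), tf.int64)
    sparse_pred = tf.RaggedTensor.from_tensor(pred, padding=padding_token).to_sparse()
    sparse_true = tf.RaggedTensor.from_tensor(true, padding=padding_token).to_sparse()
    # normalize=True divides each edit distance by the length of the true
    # sequence, which is exactly the per-sample character error rate.
    return tf.edit_distance(sparse_pred, sparse_true, normalize=True)

For the arrays above this yields [2/7, 1/7]: the first prediction is missing two characters of its seven-character reference, the second is missing one.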

Tests/test_text_utils.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import unittest

from mltu.utils.text_utils import edit_distance, get_cer, get_wer

class TestTextUtils(unittest.TestCase):

    def test_edit_distance(self):
        """ This unit test includes several test cases to cover different scenarios, including no errors,
        substitution errors, insertion errors, deletion errors, and a more complex case with multiple
        errors. It also includes a test case for empty input.
        """
        # Test simple case with no errors
        prediction_tokens = ['A', 'B', 'C']
        reference_tokens = ['A', 'B', 'C']
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 0)

        # Test simple case with one substitution error
        prediction_tokens = ['A', 'B', 'D']
        reference_tokens = ['A', 'B', 'C']
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)

        # Test simple case with one insertion error
        prediction_tokens = ['A', 'B', 'C']
        reference_tokens = ['A', 'B', 'C', 'D']
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)

        # Test simple case with one deletion error
        prediction_tokens = ['A', 'B']
        reference_tokens = ['A', 'B', 'C']
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)

        # Test more complex case with multiple errors
        prediction_tokens = ['A', 'B', 'C', 'D', 'E']
        reference_tokens = ['A', 'C', 'B', 'F', 'E']
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 3)

        # Test empty input
        prediction_tokens = []
        reference_tokens = []
        self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 0)

    def test_get_cer(self):
        # Test simple case with no errors
        preds = ['A B C']
        target = ['A B C']
        self.assertEqual(get_cer(preds, target), 0)

        # Test simple case with one character error
        preds = ['A B C']
        target = ['A B D']
        self.assertEqual(get_cer(preds, target), 1/5)

        # Test simple case with multiple character errors
        preds = ['A B C']
        target = ['D E F']
        self.assertEqual(get_cer(preds, target), 3/5)

        # Test empty input
        preds = []
        target = []
        self.assertEqual(get_cer(preds, target), 0)

        # Test simple case with different word lengths
        preds = ['ABC']
        target = ['ABCDEFG']
        self.assertEqual(get_cer(preds, target), 4/7)

    def test_get_wer(self):
        # Test simple case with no errors
        preds = 'A B C'
        target = 'A B C'
        self.assertEqual(get_wer(preds, target), [0, 0, 0])

        # Test simple case with one word error
        preds = 'A B C'
        target = 'A B D'
        self.assertEqual(get_wer(preds, target), [0, 0, 1])

        # Test simple case with multiple word errors
        preds = 'A B C'
        target = 'D E F'
        self.assertEqual(get_wer(preds, target), [1, 1, 1])

        # Test empty input
        preds = ""
        target = ""
        self.assertEqual(get_wer(preds, target), [])

        # Test simple case with different sentence lengths
        preds = ['ABC']
        target = ['ABC DEF']
        self.assertEqual(get_wer(preds, target), [1/2])

if __name__ == '__main__':
    unittest.main()
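
For reference, the expected values in test_edit_distance match the classic Levenshtein distance, which a dynamic-programming table computes as follows. This sketch is consistent with the assertions above but is not necessarily mltu's exact implementation:

def edit_distance_sketch(prediction_tokens, reference_tokens):
    rows, cols = len(prediction_tokens) + 1, len(reference_tokens) + 1
    # dp[i][j] = edits needed to turn the first i predicted tokens
    # into the first j reference tokens.
    dp = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        dp[i][0] = i
    for j in range(cols):
        dp[0][j] = j
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if prediction_tokens[i - 1] == reference_tokens[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,        # deletion
                           dp[i][j - 1] + 1,        # insertion
                           dp[i - 1][j - 1] + cost) # substitution / match
    return dp[-1][-1]

assert edit_distance_sketch(['A', 'B', 'C', 'D', 'E'], ['A', 'C', 'B', 'F', 'E']) == 3

get_cer then normalizes this distance by the reference length, e.g. an edit distance of 4 over the 7 reference characters of 'ABCDEFG' gives the 4/7 asserted above.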
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import stow
from datetime import datetime

from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    def __init__(self):
        super().__init__()
        self.model_path = stow.join('Models/04_sentence_recognition', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
        self.vocab = ''
        self.height = 96
        self.width = 1408
        self.max_text_length = 0
        self.batch_size = 32
        self.learning_rate = 0.003
        self.train_epochs = 1000
        self.train_workers = 20
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import cv2
import typing
import numpy as np

from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder, get_cer, get_wer

class ImageToWordModel(OnnxInferenceModel):
    def __init__(self, char_list: typing.Union[str, list], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        image = cv2.resize(image, self.input_shape[:2][::-1])

        image_pred = np.expand_dims(image, axis=0).astype(np.float32)

        preds = self.model.run(None, {self.input_name: image_pred})[0]

        text = ctc_decoder(preds, self.char_list)[0]

        return text

if __name__ == "__main__":
    import pandas as pd
    from tqdm import tqdm
    from mltu.configs import BaseModelConfigs

    configs = BaseModelConfigs.load("Models/04_sentence_recognition/202301041513/configs.yaml")

    model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)

    df = pd.read_csv("Models/04_sentence_recognition/202301041513/val.csv").values.tolist()

    accum_cer, accum_wer = [], []
    for image_path, label in tqdm(df):
        image = cv2.imread(image_path)

        prediction_text = model.predict(image)

        cer = get_cer(prediction_text, label)
        wer = get_wer(prediction_text, label)
        print(f"Image: {image_path}; Label: ({label}); Prediction: ({prediction_text}); CER: {cer}; WER: {wer}")

        accum_cer.append(cer)
        accum_wer.append(wer)

    print(f"Average CER: {np.average(accum_cer)}, Average WER: {np.average(accum_wer)}")
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
from keras import layers
from keras.models import Model

from mltu.model_utils import residual_block

def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):

    inputs = layers.Input(shape=input_dim, name="input")

    # normalize images here instead of in the preprocessing step
    input = layers.Lambda(lambda x: x / 255)(inputs)

    x1 = residual_block(input, 32, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x2 = residual_block(x1, 32, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x3 = residual_block(x2, 32, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x4 = residual_block(x3, 64, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x5 = residual_block(x4, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x6 = residual_block(x5, 128, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x7 = residual_block(x6, 128, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x8 = residual_block(x7, 128, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x9 = residual_block(x8, 128, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    squeezed = layers.Reshape((x9.shape[-3] * x9.shape[-2], x9.shape[-1]))(x9)

    blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
    blstm = layers.Dropout(dropout)(blstm)

    output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)

    model = Model(inputs=inputs, outputs=output)
    return model
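
Since the commit message says the new CER and WER metrics are TensorFlow metrics, they would presumably be attached to this model at compile time. The snippet below is only a sketch of that wiring: the ctc_loss helper and the metric constructor calls are assumptions, since neither appears in this diff.

import tensorflow as tf
from keras import backend as K

from mltu.metrics import CERMetric, WERMetric

def ctc_loss(y_true, y_pred):
    # Standard CTC loss via the Keras backend. For simplicity this assumes
    # every label fills the full padded width; a real training script would
    # compute per-sample label lengths instead.
    batch_len = tf.shape(y_true)[0]
    input_len = tf.fill([batch_len, 1], tf.shape(y_pred)[1])
    label_len = tf.fill([batch_len, 1], tf.shape(y_true)[1])
    return K.ctc_batch_cost(y_true, y_pred, input_len, label_len)

vocab = "abcdefg"  # hypothetical vocabulary, matching the metrics test above
model = train_model(input_dim=(96, 1408, 3), output_dim=len(vocab))
model.compile(
    optimizer="adam",
    loss=ctc_loss,
    # Assumed constructor signatures; the actual CERMetric/WERMetric API
    # is still being written in this commit.
    metrics=[CERMetric(vocab), WERMetric(vocab)],
)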
