Skip to content

Commit cd89cbb

Browse files
Merge pull request #3 from pythonlessons/feature/sentence_recognition
Feature/sentence recognition
2 parents b2a0a75 + f98cd7e commit cd89cbb

File tree

21 files changed

+1222
-100
lines changed

21 files changed

+1222
-100
lines changed

.vscode/settings.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11
{
2-
"python.analysis.typeCheckingMode": "off"
2+
"python.analysis.typeCheckingMode": "off",
3+
"python.testing.unittestArgs": [
4+
"-v",
5+
"-s",
6+
"./Tests",
7+
"-p",
8+
"*test*.py"
9+
],
10+
"python.testing.pytestEnabled": false,
11+
"python.testing.unittestEnabled": true
312
}

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1+
## [0.1.5] - 2022-01-10
2+
3+
### Changed
4+
- seperated CWERMetric to SER and WER Metrics in mltu.metrics, Character/word rate was calculatted in a wrong way
5+
- created @setter for augmentors and transformers in DataProvider, to properlly add augmentors and transformers to the pipeline
6+
- augmentors and transformers must inherit from `mltu.augmentors.base.Augmentor` and `mltu.transformers.base.Transformer` respectively
7+
- updated ImageShowCV2 transformer documentation
8+
- fixed OnnxInferenceModel in mltu.inferenceModels to use CPU even if GPU is available with force_cpu=True flag
9+
10+
### Added:
11+
- added RandomSharpen to mltu.augmentors, used for simple image augmentation;
12+
- added ImageShowCV2 to mltu.transformers, used to show image with cv2 for debugging purposes;
13+
- added better explained documentation
14+
- created unittests for CER and WER in mltu.utils.text_utils and TensorFlow verion of CER and WER mltu.metrics
15+
116
## [0.1.4] - 2022-12-21
17+
218
### Added:
319
- added mltu.augmentors (RandomBrightness, RandomRotate, RandomErodeDilate) - used for simple image augmentation;
420

Tests/test_metrics.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
import numpy as np
3+
from mltu.metrics import CERMetric, WERMetric
4+
5+
import numpy as np
6+
import tensorflow as tf
7+
8+
class TestMetrics(unittest.TestCase):
9+
10+
def to_embeddings(self, sentences, vocab):
11+
embeddings, max_len = [], 0
12+
13+
for sentence in sentences:
14+
embedding = []
15+
for character in sentence:
16+
embedding.append(vocab.index(character))
17+
embeddings.append(embedding)
18+
max_len = max(max_len, len(embedding))
19+
return embeddings, max_len
20+
21+
def setUp(self) -> None:
22+
true_words = ['Who are you', 'I am a student', 'I am a teacher', 'Just different sentence length']
23+
pred_words = ['Who are you', 'I am a ztudent', 'I am A reacher', 'Just different length']
24+
25+
vocab = set()
26+
for sen in true_words + pred_words:
27+
for character in sen:
28+
vocab.add(character)
29+
self.vocab = "".join(vocab)
30+
31+
sentence_true, max_len_true = self.to_embeddings(true_words, self.vocab)
32+
sentence_pred, max_len_pred = self.to_embeddings(pred_words, self.vocab)
33+
34+
max_len = max(max_len_true, max_len_pred)
35+
padding_length = 64
36+
37+
self.sen_true = [np.pad(sen, (0, max_len - len(sen)), 'constant', constant_values=len(self.vocab)) for sen in sentence_true]
38+
self.sen_pred = [np.pad(sen, (0, padding_length - len(sen)), 'constant', constant_values=-1) for sen in sentence_pred]
39+
40+
def test_CERMetric(self):
41+
vocabulary = tf.constant(list(self.vocab))
42+
cer = CERMetric.get_cer(self.sen_true, self.sen_pred, vocabulary).numpy()
43+
44+
self.assertTrue(np.array_equal(cer, np.array([0.0, 0.071428575, 0.14285715, 0.42857143], dtype=np.float32)))
45+
46+
def test_WERMetric(self):
47+
vocabulary = tf.constant(list(self.vocab))
48+
wer = WERMetric.get_wer(self.sen_true, self.sen_pred, vocabulary).numpy()
49+
50+
self.assertTrue(np.array_equal(wer, np.array([0., 0.25, 0.5, 0.33333334], dtype=np.float32)))
51+
52+
if __name__ == "__main__":
53+
unittest.main()

Tests/test_text_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import unittest
2+
3+
from mltu.utils.text_utils import edit_distance, get_cer, get_wer
4+
5+
class TestTextUtils(unittest.TestCase):
6+
7+
def test_edit_distance(self):
8+
""" This unit test includes several test cases to cover different scenarios, including no errors,
9+
substitution errors, insertion errors, deletion errors, and a more complex case with multiple
10+
errors. It also includes a test case for empty input.
11+
"""
12+
# Test simple case with no errors
13+
prediction_tokens = ['A', 'B', 'C']
14+
reference_tokens = ['A', 'B', 'C']
15+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 0)
16+
17+
# Test simple case with one substitution error
18+
prediction_tokens = ['A', 'B', 'D']
19+
reference_tokens = ['A', 'B', 'C']
20+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
21+
22+
# Test simple case with one insertion error
23+
prediction_tokens = ['A', 'B', 'C']
24+
reference_tokens = ['A', 'B', 'C', 'D']
25+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
26+
27+
# Test simple case with one deletion error
28+
prediction_tokens = ['A', 'B']
29+
reference_tokens = ['A', 'B', 'C']
30+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
31+
32+
# Test more complex case with multiple errors
33+
prediction_tokens = ['A', 'B', 'C', 'D', 'E']
34+
reference_tokens = ['A', 'C', 'B', 'F', 'E']
35+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 3)
36+
37+
# Test empty input
38+
prediction_tokens = []
39+
reference_tokens = []
40+
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 0)
41+
42+
def test_get_cer(self):
43+
# Test simple case with no errors
44+
preds = ['A B C']
45+
target = ['A B C']
46+
self.assertEqual(get_cer(preds, target), 0)
47+
48+
# Test simple case with one character error
49+
preds = ['A B C']
50+
target = ['A B D']
51+
self.assertEqual(get_cer(preds, target), 1/5)
52+
53+
# Test simple case with multiple character errors
54+
preds = ['A B C']
55+
target = ['D E F']
56+
self.assertEqual(get_cer(preds, target), 3/5)
57+
58+
# Test empty input
59+
preds = []
60+
target = []
61+
self.assertEqual(get_cer(preds, target), 0)
62+
63+
# Test simple case with different word lengths
64+
preds = ['ABC']
65+
target = ['ABCDEFG']
66+
self.assertEqual(get_cer(preds, target), 4/7)
67+
68+
def test_get_wer(self):
69+
# Test simple case with no errors
70+
preds = 'A B C'
71+
target = 'A B C'
72+
self.assertEqual(get_wer(preds, target), 0)
73+
74+
# Test simple case with one word error
75+
preds = 'A B C'
76+
target = 'A B D'
77+
self.assertEqual(get_wer(preds, target), 1/3)
78+
79+
# Test simple case with multiple word errors
80+
preds = 'A B C'
81+
target = 'D E F'
82+
self.assertEqual(get_wer(preds, target), 1)
83+
84+
# Test empty input
85+
preds = ""
86+
target = ""
87+
self.assertEqual(get_wer(preds, target), 0)
88+
89+
# Test simple case with different sentence lengths
90+
preds = ['ABC']
91+
target = ['ABC DEF']
92+
self.assertEqual(get_wer(preds, target), 1)
93+
94+
if __name__ == '__main__':
95+
unittest.main()

Tutorials/02_captcha_to_text/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def download_and_unzip(url, extract_to='Datasets'):
6060
# Augment training data with random brightness, rotation and erode/dilate
6161
train_data_provider.augmentors = [RandomBrightness(), RandomRotate(), RandomErodeDilate()]
6262

63+
# Creating TensorFlow model architecture
6364
model = train_model(
6465
input_dim = (configs.height, configs.width, 3),
6566
output_dim = len(configs.vocab),
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import stow
2+
from datetime import datetime
3+
4+
from mltu.configs import BaseModelConfigs
5+
6+
class ModelConfigs(BaseModelConfigs):
7+
def __init__(self):
8+
super().__init__()
9+
self.model_path = stow.join('Models/03_handwriting_recognition', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
10+
self.vocab = ''
11+
self.height = 32
12+
self.width = 128
13+
self.max_text_length = 0
14+
self.batch_size = 64
15+
self.learning_rate = 0.001
16+
self.train_epochs = 1000
17+
self.train_workers = 20
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import cv2
2+
import typing
3+
import numpy as np
4+
5+
from mltu.inferenceModel import OnnxInferenceModel
6+
from mltu.utils.text_utils import ctc_decoder, get_cer
7+
8+
class ImageToWordModel(OnnxInferenceModel):
9+
def __init__(self, char_list: typing.Union[str, list], *args, **kwargs):
10+
super().__init__(*args, **kwargs)
11+
self.char_list = char_list
12+
13+
def predict(self, image: np.ndarray):
14+
image = cv2.resize(image, self.input_shape[:2][::-1])
15+
16+
image_pred = np.expand_dims(image, axis=0).astype(np.float32)
17+
18+
preds = self.model.run(None, {self.input_name: image_pred})[0]
19+
20+
text = ctc_decoder(preds, self.char_list)[0]
21+
22+
return text
23+
24+
if __name__ == "__main__":
25+
import pandas as pd
26+
from tqdm import tqdm
27+
from mltu.configs import BaseModelConfigs
28+
29+
configs = BaseModelConfigs.load("Models/03_handwriting_recognition/202212290905/configs.yaml")
30+
31+
model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)
32+
33+
df = pd.read_csv("Models/03_handwriting_recognition/202212290905/val.csv").values.tolist()
34+
35+
accum_cer = []
36+
for image_path, label in tqdm(df):
37+
image = cv2.imread(image_path)
38+
39+
prediction_text = model.predict(image)
40+
41+
cer = get_cer(prediction_text, label)
42+
print(f"Image: {image_path}, Label: {label}, Prediction: {prediction_text}, CER: {cer}")
43+
44+
accum_cer.append(cer)
45+
46+
print(f"Average CER: {np.average(accum_cer)}")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from keras import layers
2+
from keras.models import Model
3+
4+
from mltu.model_utils import residual_block
5+
6+
def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
7+
8+
inputs = layers.Input(shape=input_dim, name="input")
9+
10+
# normalize images here instead in preprocessing step
11+
input = layers.Lambda(lambda x: x / 255)(inputs)
12+
13+
x1 = residual_block(input, 16, activation=activation, skip_conv=True, strides=1, dropout=dropout)
14+
15+
x2 = residual_block(x1, 16, activation=activation, skip_conv=True, strides=2, dropout=dropout)
16+
x3 = residual_block(x2, 16, activation=activation, skip_conv=False, strides=1, dropout=dropout)
17+
18+
x4 = residual_block(x3, 32, activation=activation, skip_conv=True, strides=2, dropout=dropout)
19+
x5 = residual_block(x4, 32, activation=activation, skip_conv=False, strides=1, dropout=dropout)
20+
21+
x6 = residual_block(x5, 64, activation=activation, skip_conv=True, strides=2, dropout=dropout)
22+
x7 = residual_block(x6, 64, activation=activation, skip_conv=True, strides=1, dropout=dropout)
23+
24+
x8 = residual_block(x7, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)
25+
x9 = residual_block(x8, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)
26+
27+
squeezed = layers.Reshape((x9.shape[-3] * x9.shape[-2], x9.shape[-1]))(x9)
28+
29+
blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
30+
blstm = layers.Dropout(dropout)(blstm)
31+
32+
output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)
33+
34+
model = Model(inputs=inputs, outputs=output)
35+
return model

0 commit comments

Comments
 (0)