
Commit 29cad5a

Finished 4th tutorial code, fixed CER and WER metrics, wrote some tests for them
1 parent 57bf2da commit 29cad5a

File tree

9 files changed: +67 -76 lines changed


CHANGELOG.md

Lines changed: 3 additions & 2 deletions
@@ -1,14 +1,15 @@
 ## [0.1.5] - 2022-01-03
 
 ### Changed
-- changed CWERMetric in mltu.metrics; the character/word error rate was calculated in a wrong way
+- separated CWERMetric into CER and WER metrics in mltu.metrics; the character/word error rate was calculated in a wrong way
 - created @setter for augmentors and transformers in DataProvider, to properly add augmentors and transformers to the pipeline
 - augmentors and transformers must inherit from `mltu.augmentors.base.Augmentor` and `mltu.transformers.base.Transformer` respectively
-- added better explained documentation
 
 ### Added:
 - added RandomSharpen to mltu.augmentors, used for simple image augmentation;
 - added ImageShowCV2 to mltu.transformers, used to show image with cv2 for debugging purposes;
+- added better explained documentation
+- created unit tests for CER and WER in mltu.utils.text_utils and for the TensorFlow versions of CER and WER in mltu.metrics
 
 ## [0.1.4] - 2022-12-21
 
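For context on the metric fix this changelog describes: both rates are normalized edit distances, CER over characters and WER over whitespace-separated words, each divided by the length of the reference. A minimal sketch of the two formulas, using a generic Levenshtein helper rather than mltu's own edit_distance:

def levenshtein(a, b):
    # Classic dynamic-programming edit distance between two sequences.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]

# CER: character-level edits divided by the number of reference characters
cer = levenshtein("helo world", "hello world") / len("hello world")   # 1/11
# WER: word-level edits divided by the number of reference words
wer = levenshtein("helo world".split(), "hello world".split()) / 2    # 1/2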

Tests/test_metrics.py

Lines changed: 35 additions & 47 deletions
@@ -1,65 +1,53 @@
+import unittest
 import numpy as np
 from mltu.metrics import CERMetric, WERMetric
 
-from mltu.utils.text_utils import get_wer as wer
-
-import cv2
-import typing
 import numpy as np
 import tensorflow as tf
 
-if __name__ == "__main__":
-    import pandas as pd
-    from tqdm import tqdm
-)
-
-
+class TestMetrics(unittest.TestCase):
 
-# sentences_true = ['helo love', 'helo home', 'helo world']
-# sentences_pred = ['helo python', 'helo home', 'helo python here']
+    def to_embeddings(self, sentences, vocab):
+        embeddings, max_len = [], 0
 
-# def to_embeddings(sentences, vocab):
-#     embeddings, max_len = [], 0
+        for sentence in sentences:
+            embedding = []
+            for character in sentence:
+                embedding.append(vocab.index(character))
+            embeddings.append(embedding)
+            max_len = max(max_len, len(embedding))
+        return embeddings, max_len
 
-#     for sentence in sentences:
-#         embedding = []
-#         for character in sentence:
-#             embedding.append(vocab.index(character))
-#         embeddings.append(embedding)
-#         max_len = max(max_len, len(embedding))
-#     return embeddings, max_len
+    def setUp(self) -> None:
+        true_words = ['Who are you', 'I am a student', 'I am a teacher', 'Just different sentence length']
+        pred_words = ['Who are you', 'I am a ztudent', 'I am A reacher', 'Just different length']
 
-# vocab = set()
-# for sen in sentences_true + sentences_pred:
-#     for character in sen:
-#         vocab.add(character)
-# vocab = "".join(vocab)
+        vocab = set()
+        for sen in true_words + pred_words:
+            for character in sen:
+                vocab.add(character)
+        self.vocab = "".join(vocab)
 
-# sen1, max_len = to_embeddings(sentences_true, vocab)
-# sen2, _ = to_embeddings(sentences_pred, vocab)
+        sentence_true, max_len_true = self.to_embeddings(true_words, self.vocab)
+        sentence_pred, max_len_pred = self.to_embeddings(pred_words, self.vocab)
 
-# sen_true = [np.pad(sen, (0, max_len - len(sen)), 'constant', constant_values=len(vocab)) for sen in sen1]
-# sen_pred = [np.pad(sen, (0, 24 - len(sen)), 'constant', constant_values=-1) for sen in sen2]
+        max_len = max(max_len_true, max_len_pred)
+        padding_length = 64
 
+        self.sen_true = [np.pad(sen, (0, max_len - len(sen)), 'constant', constant_values=len(self.vocab)) for sen in sentence_true]
+        self.sen_pred = [np.pad(sen, (0, padding_length - len(sen)), 'constant', constant_values=-1) for sen in sentence_pred]
 
-# tf_vocab = tf.constant(list(vocab))
+    def test_CERMetric(self):
+        vocabulary = tf.constant(list(self.vocab))
+        cer = CERMetric.get_cer(self.sen_true, self.sen_pred, vocabulary).numpy()
 
-# distance = WERMetric.get_wer(sen_pred, sen_true, vocab=tf_vocab)
+        self.assertTrue(np.array_equal(cer, np.array([0.0, 0.071428575, 0.14285715, 0.42857143], dtype=np.float32)))
 
-# d = wer(sentences_pred, sentences_true)
+    def test_WERMetric(self):
+        vocabulary = tf.constant(list(self.vocab))
+        wer = WERMetric.get_wer(self.sen_true, self.sen_pred, vocabulary).numpy()
 
-# print(list(distance.numpy()))
-# print(d)
+        self.assertTrue(np.array_equal(wer, np.array([0., 0.25, 0.5, 0.33333334], dtype=np.float32)))
 
-
-word_true = [
-    [1, 2, 3, 4, 5, 6, 1],
-    [2, 3, 4, 5, 6, 1, 1]
-]
-word_pred = [
-    [1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
-    [2, 3, 4, 5, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
-]
-vocabulary = tf.constant(list("abcdefg"))
-
-distance = CERMetric.get_cer(word_pred, word_true, vocabulary)
+if __name__ == "__main__":
+    unittest.main()
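The expected arrays follow from normalized edit distances between each sentence pair: 'I am a ztudent' vs 'I am a student' has one wrong character out of fourteen (1/14 ≈ 0.0714) and one wrong word out of four (0.25); 'I am A reacher' vs 'I am a teacher' has two of each (2/14 ≈ 0.1429 and 2/4 = 0.5). A quick cross-check against the plain-Python helpers, assuming get_cer shares get_wer's (preds, target) string signature:

from mltu.utils.text_utils import get_cer, get_wer

# one substitution ('z' for 's') over fourteen reference characters
assert abs(get_cer('I am a ztudent', 'I am a student') - 1 / 14) < 1e-6
# one wrong word out of four
assert get_wer('I am a ztudent', 'I am a student') == 0.25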

Tests/test_text_utils.py

Lines changed: 5 additions & 5 deletions
@@ -69,27 +69,27 @@ def test_get_wer(self):
         # Test simple case with no errors
         preds = 'A B C'
         target = 'A B C'
-        self.assertEqual(get_wer(preds, target), [0, 0, 0])
+        self.assertEqual(get_wer(preds, target), 0)
 
         # Test simple case with one word error
         preds = 'A B C'
         target = 'A B D'
-        self.assertEqual(get_wer(preds, target), [0, 0, 1])
+        self.assertEqual(get_wer(preds, target), 1/3)
 
         # Test simple case with multiple word errors
         preds = 'A B C'
         target = 'D E F'
-        self.assertEqual(get_wer(preds, target), [1, 1, 1])
+        self.assertEqual(get_wer(preds, target), 1)
 
         # Test empty input
         preds = ""
         target = ""
-        self.assertEqual(get_wer(preds, target), [])
+        self.assertEqual(get_wer(preds, target), 0)
 
         # Test simple case with different sentence lengths
         preds = ['ABC']
         target = ['ABC DEF']
-        self.assertEqual(get_wer(preds, target), [1/2])
+        self.assertEqual(get_wer(preds, target), 1)
 
 if __name__ == '__main__':
     unittest.main()
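These updated expectations encode get_wer's new contract: a single float per call rather than a per-word list. 'A B C' against 'A B D' has one wrong word out of three target words, hence 1/3; empty inputs return 0 by convention; and the pre-tokenized pair ['ABC'] vs ['ABC DEF'] counts each list element as one whole token, so a single substitution over one target token gives 1.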

Tutorials/04_sentence_recognition/configs.py

Lines changed: 1 addition & 1 deletion
@@ -12,6 +12,6 @@ def __init__(self):
         self.width = 1408
         self.max_text_length = 0
         self.batch_size = 32
-        self.learning_rate = 0.003
+        self.learning_rate = 0.001
         self.train_epochs = 1000
         self.train_workers = 20

Tutorials/04_sentence_recognition/inferenceModel.py

Lines changed: 2 additions & 2 deletions
@@ -26,11 +26,11 @@ def predict(self, image: np.ndarray):
 from tqdm import tqdm
 from mltu.configs import BaseModelConfigs
 
-configs = BaseModelConfigs.load("Models/04_sentence_recognition/202301041513/configs.yaml")
+configs = BaseModelConfigs.load("Models/04_sentence_recognition/202301060816/configs.yaml")
 
 model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)
 
-df = pd.read_csv("Models/04_sentence_recognition/202301041513/val.csv").values.tolist()
+df = pd.read_csv("Models/04_sentence_recognition/202301060816/val.csv").values.tolist()
 
 accum_cer, accum_wer = [], []
 for image_path, label in tqdm(df):

Tutorials/04_sentence_recognition/model.py

Lines changed: 4 additions & 1 deletion
@@ -26,7 +26,10 @@ def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
 
     squeezed = layers.Reshape((x9.shape[-3] * x9.shape[-2], x9.shape[-1]))(x9)
 
-    blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
+    blstm = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(squeezed)
+    blstm = layers.Dropout(dropout)(blstm)
+
+    blstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(blstm)
     blstm = layers.Dropout(dropout)(blstm)
 
     output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)
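Viewed on its own, the new recurrent head replaces the single bidirectional LSTM with a wide-then-narrow stack, each layer followed by dropout. A minimal standalone sketch; the input shape and output size below are placeholders, not the tutorial's actual dimensions:

import tensorflow as tf
from tensorflow.keras import layers

dropout = 0.2
inputs = tf.keras.Input(shape=(32, 256))  # (time steps, features), placeholder shape

blstm = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(inputs)
blstm = layers.Dropout(dropout)(blstm)

blstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(blstm)
blstm = layers.Dropout(dropout)(blstm)

# output_dim + 1 units: one extra class for the CTC blank token
output = layers.Dense(80 + 1, activation='softmax', name="output")(blstm)
model = tf.keras.Model(inputs=inputs, outputs=output)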

Tutorials/04_sentence_recognition/train.py

Lines changed: 5 additions & 2 deletions
@@ -10,7 +10,7 @@
 from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
 from mltu.losses import CTCloss
 from mltu.callbacks import Model2onnx, TrainLogger
-from mltu.metrics import CWERMetric
+from mltu.metrics import CERMetric, WERMetric
 
 from model import train_model
 from configs import ModelConfigs
@@ -89,7 +89,10 @@
 model.compile(
     optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate),
     loss=CTCloss(),
-    metrics=[CWERMetric(padding_token=len(configs.vocab))],
+    metrics=[
+        CERMetric(vocabulary=configs.vocab),
+        WERMetric(vocabulary=configs.vocab)
+    ],
     run_eagerly=False
 )
 model.summary(line_length=110)

mltu/metrics.py

Lines changed: 6 additions & 4 deletions
@@ -78,7 +78,7 @@ class CERMetric(tf.keras.metrics.Metric):
         name: (Optional) string name of the metric instance.
         **kwargs: Additional keyword arguments.
     """
-    def __init__(self, vocabulary, name='CWER', **kwargs):
+    def __init__(self, vocabulary, name='CER', **kwargs):
         # Initialize the base Metric class
         super(CERMetric, self).__init__(name=name, **kwargs)
 
@@ -103,11 +103,12 @@ def get_cer(pred_decoded, y_true, vocab, padding=-1):
             tf.Tensor: The CER between the predicted labels and true labels
         """
         # Keep only valid indices in the predicted labels tensor, replacing invalid indices with padding token
-        valid_pred_indices = tf.less(pred_decoded, tf.shape(vocab)[0])
+        vocab_length = tf.cast(tf.shape(vocab)[0], tf.int64)
+        valid_pred_indices = tf.less(pred_decoded, vocab_length)
         valid_pred = tf.where(valid_pred_indices, pred_decoded, padding)
 
         # Keep only valid indices in the true labels tensor, replacing invalid indices with padding token
-        valid_true_indices = tf.less(y_true, tf.shape(vocab)[0])
+        valid_true_indices = tf.less(y_true, vocab_length)
         valid_true = tf.where(valid_true_indices, y_true, padding)
 
         # Convert the valid predicted labels tensor to a sparse tensor
@@ -186,7 +187,8 @@ def preprocess_dense(dense_input: tf.Tensor, vocab: tf.Tensor, padding=-1) -> tf.SparseTensor:
             tf.SparseTensor: The sparse tensor with given vocabulary
         """
         # Keep only the valid indices of the dense input tensor
-        valid_indices = tf.less(dense_input, tf.shape(vocab)[0])
+        vocab_length = tf.cast(tf.shape(vocab)[0], tf.int64)
+        valid_indices = tf.less(dense_input, vocab_length)
         valid_input = tf.where(valid_indices, dense_input, padding)
 
         # Convert the valid input tensor to a ragged tensor with padding
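The new tf.cast is needed because tf.shape always returns int32, while the dense tensors coming out of CTC decoding are int64, and tf.less rejects mismatched dtypes. A small reproduction of the fixed pattern, with illustrative values rather than the library's real inputs:

import tensorflow as tf

pred_decoded = tf.constant([[1, 2, 99]], dtype=tf.int64)  # CTC decoding yields int64 indices
vocab = tf.constant(list("abcdefg"))

# tf.less(pred_decoded, tf.shape(vocab)[0]) would fail: int64 compared against int32
vocab_length = tf.cast(tf.shape(vocab)[0], tf.int64)
valid = tf.where(tf.less(pred_decoded, vocab_length), pred_decoded, -1)
print(valid.numpy())  # [[ 1  2 -1]]: index 99 is out of vocabulary, replaced by the padding value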

mltu/utils/text_utils.py

Lines changed: 6 additions & 12 deletions
@@ -82,6 +82,7 @@ def get_cer(
         return 0.0
 
     cer = errors / total
+
    return cer
 
 def get_wer(
@@ -102,20 +103,13 @@ def get_wer(
     if isinstance(target, str):
         target = target.split()
 
-    assert len(preds) == len(target), 'preds and target must have the same length'
-
-    wer = []
-    for pred, tgt in zip(preds, target):
-        errors = edit_distance(pred.split(), tgt.split())
-        total_words = len(tgt.split())
+    errors = edit_distance(preds, target)
+    total_words = len(target)
 
-        if total_words == 0:
-            wer.append(0)
-            continue
-
-        wer.append(errors / total_words)
+    if total_words == 0:
+        return 0.0
 
-    return wer
+    return errors / total_words
 
 if __name__ == '__main__':
     c1 = 'ROKAS'
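After this rewrite, get_wer accepts either raw strings, which it splits on whitespace itself, or pre-tokenized lists whose elements are compared as whole tokens. A brief usage sketch of the new behaviour, with the values taken from the updated tests:

from mltu.utils.text_utils import get_wer

assert abs(get_wer('A B C', 'A B D') - 1 / 3) < 1e-9  # one of three words wrong
assert get_wer('', '') == 0.0                         # empty inputs return 0.0
assert get_wer(['ABC'], ['ABC DEF']) == 1.0           # list elements are whole tokens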
