Skip to content

Commit 4facd52

Browse files
Merge pull request #7 from seidnerj/main
Additional functionality + code styling fixes Ok, now merging to develop and later will do deeper checks
2 parents 2ce972e + 6d5b14a commit 4facd52

File tree

47 files changed

+662
-378
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+662
-378
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@ Datasets/*
77
Models/*
88
dist
99

10-
!*.md
10+
!*.md
11+
12+
.idea
13+
.python-version

Tests/test_text_utils.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,28 @@ def test_edit_distance(self):
1010
errors. It also includes a test case for empty input.
1111
"""
1212
# Test simple case with no errors
13-
prediction_tokens = ['A', 'B', 'C']
14-
reference_tokens = ['A', 'B', 'C']
13+
prediction_tokens = ["A", "B", "C"]
14+
reference_tokens = ["A", "B", "C"]
1515
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 0)
1616

1717
# Test simple case with one substitution error
18-
prediction_tokens = ['A', 'B', 'D']
19-
reference_tokens = ['A', 'B', 'C']
18+
prediction_tokens = ["A", "B", "D"]
19+
reference_tokens = ["A", "B", "C"]
2020
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
2121

2222
# Test simple case with one insertion error
23-
prediction_tokens = ['A', 'B', 'C']
24-
reference_tokens = ['A', 'B', 'C', 'D']
23+
prediction_tokens = ["A", "B", "C"]
24+
reference_tokens = ["A", "B", "C", "D"]
2525
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
2626

2727
# Test simple case with one deletion error
28-
prediction_tokens = ['A', 'B']
29-
reference_tokens = ['A', 'B', 'C']
28+
prediction_tokens = ["A", "B"]
29+
reference_tokens = ["A", "B", "C"]
3030
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 1)
3131

3232
# Test more complex case with multiple errors
33-
prediction_tokens = ['A', 'B', 'C', 'D', 'E']
34-
reference_tokens = ['A', 'C', 'B', 'F', 'E']
33+
prediction_tokens = ["A", "B", "C", "D", "E"]
34+
reference_tokens = ["A", "C", "B", "F", "E"]
3535
self.assertEqual(edit_distance(prediction_tokens, reference_tokens), 3)
3636

3737
# Test empty input
@@ -41,18 +41,18 @@ def test_edit_distance(self):
4141

4242
def test_get_cer(self):
4343
# Test simple case with no errors
44-
preds = ['A B C']
45-
target = ['A B C']
44+
preds = ["A B C"]
45+
target = ["A B C"]
4646
self.assertEqual(get_cer(preds, target), 0)
4747

4848
# Test simple case with one character error
49-
preds = ['A B C']
50-
target = ['A B D']
49+
preds = ["A B C"]
50+
target = ["A B D"]
5151
self.assertEqual(get_cer(preds, target), 1/5)
5252

5353
# Test simple case with multiple character errors
54-
preds = ['A B C']
55-
target = ['D E F']
54+
preds = ["A B C"]
55+
target = ["D E F"]
5656
self.assertEqual(get_cer(preds, target), 3/5)
5757

5858
# Test empty input
@@ -61,24 +61,24 @@ def test_get_cer(self):
6161
self.assertEqual(get_cer(preds, target), 0)
6262

6363
# Test simple case with different word lengths
64-
preds = ['ABC']
65-
target = ['ABCDEFG']
64+
preds = ["ABC"]
65+
target = ["ABCDEFG"]
6666
self.assertEqual(get_cer(preds, target), 4/7)
6767

6868
def test_get_wer(self):
6969
# Test simple case with no errors
70-
preds = 'A B C'
71-
target = 'A B C'
70+
preds = "A B C"
71+
target = "A B C"
7272
self.assertEqual(get_wer(preds, target), 0)
7373

7474
# Test simple case with one word error
75-
preds = 'A B C'
76-
target = 'A B D'
75+
preds = "A B C"
76+
target = "A B D"
7777
self.assertEqual(get_wer(preds, target), 1/3)
7878

7979
# Test simple case with multiple word errors
80-
preds = 'A B C'
81-
target = 'D E F'
80+
preds = "A B C"
81+
target = "D E F"
8282
self.assertEqual(get_wer(preds, target), 1)
8383

8484
# Test empty input
@@ -87,9 +87,10 @@ def test_get_wer(self):
8787
self.assertEqual(get_wer(preds, target), 0)
8888

8989
# Test simple case with different sentence lengths
90-
preds = ['ABC']
91-
target = ['ABC DEF']
90+
preds = ["ABC"]
91+
target = ["ABC DEF"]
9292
self.assertEqual(get_wer(preds, target), 1)
9393

94-
if __name__ == '__main__':
95-
unittest.main()
94+
95+
if __name__ == "__main__":
96+
unittest.main()

Tutorials/01_image_to_word/configs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
from mltu.configs import BaseModelConfigs
55

6+
67
class ModelConfigs(BaseModelConfigs):
78
def __init__(self):
89
super().__init__()
9-
self.model_path = os.path.join('Models/1_image_to_word', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
10-
self.vocab = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
10+
self.model_path = os.path.join("Models/1_image_to_word", datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
11+
self.vocab = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
1112
self.height = 32
1213
self.width = 128
1314
self.max_text_length = 23

Tutorials/01_image_to_word/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from mltu.tensorflow.model_utils import residual_block
55

6-
def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
6+
7+
def train_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):
78

89
inputs = layers.Input(shape=input_dim, name="input")
910

@@ -24,7 +25,7 @@ def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
2425

2526
blstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(squeezed)
2627

27-
output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)
28+
output = layers.Dense(output_dim + 1, activation="softmax", name="output")(blstm)
2829

2930
model = Model(inputs=inputs, outputs=output)
3031
return model

Tutorials/01_image_to_word/train.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from tqdm import tqdm
33
import tensorflow as tf
44

5-
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices('GPU')]
5+
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
66
except: pass
77

88
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
99

1010
from mltu.dataProvider import DataProvider
1111
from mltu.preprocessors import ImageReader
12+
from mltu.annotations.images import CVImage
1213
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
1314
from mltu.tensorflow.losses import CTCloss
1415
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
1516
from mltu.tensorflow.metrics import CWERMetric
1617

18+
1719
from model import train_model
1820
from configs import ModelConfigs
1921

@@ -49,7 +51,7 @@ def read_annotation_file(annotation_path):
4951
dataset=train_dataset,
5052
skip_validation=True,
5153
batch_size=configs.batch_size,
52-
data_preprocessors=[ImageReader()],
54+
data_preprocessors=[ImageReader(CVImage)],
5355
transformers=[
5456
ImageResizer(configs.width, configs.height),
5557
LabelIndexer(configs.vocab),
@@ -62,7 +64,7 @@ def read_annotation_file(annotation_path):
6264
dataset=val_dataset,
6365
skip_validation=True,
6466
batch_size=configs.batch_size,
65-
data_preprocessors=[ImageReader()],
67+
data_preprocessors=[ImageReader(CVImage)],
6668
transformers=[
6769
ImageResizer(configs.width, configs.height),
6870
LabelIndexer(configs.vocab),
@@ -87,11 +89,11 @@ def read_annotation_file(annotation_path):
8789
os.makedirs(configs.model_path, exist_ok=True)
8890

8991
# Define callbacks
90-
earlystopper = EarlyStopping(monitor='val_CER', patience=10, verbose=1)
91-
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor='val_CER', verbose=1, save_best_only=True, mode='min')
92+
earlystopper = EarlyStopping(monitor="val_CER", patience=10, verbose=1)
93+
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
9294
trainLogger = TrainLogger(configs.model_path)
93-
tb_callback = TensorBoard(f'{configs.model_path}/logs', update_freq=1)
94-
reduceLROnPlat = ReduceLROnPlateau(monitor='val_CER', factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode='auto')
95+
tb_callback = TensorBoard(f"{configs.model_path}/logs", update_freq=1)
96+
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode="auto")
9597
model2onnx = Model2onnx(f"{configs.model_path}/model.h5")
9698

9799
# Train the model
@@ -104,5 +106,5 @@ def read_annotation_file(annotation_path):
104106
)
105107

106108
# Save training and validation datasets as csv files
107-
train_data_provider.to_csv(os.path.join(configs.model_path, 'train.csv'))
108-
val_data_provider.to_csv(os.path.join(configs.model_path, 'val.csv'))
109+
train_data_provider.to_csv(os.path.join(configs.model_path, "train.csv"))
110+
val_data_provider.to_csv(os.path.join(configs.model_path, "val.csv"))

Tutorials/02_captcha_to_text/configs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
from mltu.configs import BaseModelConfigs
55

6+
67
class ModelConfigs(BaseModelConfigs):
78
def __init__(self):
89
super().__init__()
9-
self.model_path = os.path.join('Models/02_captcha_to_text', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
10-
self.vocab = ''
10+
self.model_path = os.path.join("Models/02_captcha_to_text", datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
11+
self.vocab = ""
1112
self.height = 50
1213
self.width = 200
1314
self.max_text_length = 0

Tutorials/02_captcha_to_text/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from mltu.tensorflow.model_utils import residual_block
55

6-
def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
6+
7+
def train_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):
78

89
inputs = layers.Input(shape=input_dim, name="input")
910

@@ -29,7 +30,7 @@ def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
2930
blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
3031
blstm = layers.Dropout(dropout)(blstm)
3132

32-
output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)
33+
output = layers.Dense(output_dim + 1, activation="softmax", name="output")(blstm)
3334

3435
model = Model(inputs=inputs, outputs=output)
35-
return model
36+
return model

Tutorials/02_captcha_to_text/train.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import tensorflow as tf
2-
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices('GPU')]
2+
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
33
except: pass
44

55
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
@@ -12,6 +12,7 @@
1212
from mltu.preprocessors import ImageReader
1313
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
1414
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate
15+
from mltu.annotations.images import CVImage
1516

1617
from model import train_model
1718
from configs import ModelConfigs
@@ -21,17 +22,20 @@
2122
from io import BytesIO
2223
from zipfile import ZipFile
2324

24-
def download_and_unzip(url, extract_to='Datasets'):
25+
26+
def download_and_unzip(url, extract_to="Datasets"):
2527
http_response = urlopen(url)
2628
zipfile = ZipFile(BytesIO(http_response.read()))
2729
zipfile.extractall(path=extract_to)
2830

29-
if not os.path.exists(os.path.join('Datasets', 'captcha_images_v2')):
30-
download_and_unzip('https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip', extract_to='Datasets')
31+
32+
if not os.path.exists(os.path.join("Datasets", "captcha_images_v2")):
33+
download_and_unzip("https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip",
34+
extract_to="Datasets")
3135

3236
# Create a list of all the images and labels in the dataset
3337
dataset, vocab, max_len = [], set(), 0
34-
captcha_path = os.path.join('Datasets', 'captcha_images_v2')
38+
captcha_path = os.path.join("Datasets", "captcha_images_v2")
3539
for file in os.listdir(captcha_path):
3640
file_path = os.path.join(captcha_path, file)
3741
label = os.path.splitext(file)[0] # Get the file name without the extension
@@ -51,7 +55,7 @@ def download_and_unzip(url, extract_to='Datasets'):
5155
dataset=dataset,
5256
skip_validation=True,
5357
batch_size=configs.batch_size,
54-
data_preprocessors=[ImageReader()],
58+
data_preprocessors=[ImageReader(CVImage)],
5559
transformers=[
5660
ImageResizer(configs.width, configs.height),
5761
LabelIndexer(configs.vocab),
@@ -82,11 +86,11 @@ def download_and_unzip(url, extract_to='Datasets'):
8286
os.makedirs(configs.model_path, exist_ok=True)
8387

8488
# Define callbacks
85-
earlystopper = EarlyStopping(monitor='val_CER', patience=50, verbose=1)
86-
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor='val_CER', verbose=1, save_best_only=True, mode='min')
89+
earlystopper = EarlyStopping(monitor="val_CER", patience=50, verbose=1)
90+
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
8791
trainLogger = TrainLogger(configs.model_path)
88-
tb_callback = TensorBoard(f'{configs.model_path}/logs', update_freq=1)
89-
reduceLROnPlat = ReduceLROnPlateau(monitor='val_CER', factor=0.9, min_delta=1e-10, patience=20, verbose=1, mode='auto')
92+
tb_callback = TensorBoard(f"{configs.model_path}/logs", update_freq=1)
93+
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=20, verbose=1, mode="auto")
9094
model2onnx = Model2onnx(f"{configs.model_path}/model.h5")
9195

9296
# Train the model
@@ -99,5 +103,5 @@ def download_and_unzip(url, extract_to='Datasets'):
99103
)
100104

101105
# Save training and validation datasets as csv files
102-
train_data_provider.to_csv(os.path.join(configs.model_path, 'train.csv'))
103-
val_data_provider.to_csv(os.path.join(configs.model_path, 'val.csv'))
106+
train_data_provider.to_csv(os.path.join(configs.model_path, "train.csv"))
107+
val_data_provider.to_csv(os.path.join(configs.model_path, "val.csv"))

Tutorials/03_handwriting_recognition/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
class ModelConfigs(BaseModelConfigs):
77
def __init__(self):
88
super().__init__()
9-
self.model_path = os.path.join('Models/03_handwriting_recognition', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
10-
self.vocab = ''
9+
self.model_path = os.path.join("Models/03_handwriting_recognition", datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
10+
self.vocab = ""
1111
self.height = 32
1212
self.width = 128
1313
self.max_text_length = 0

Tutorials/03_handwriting_recognition/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from mltu.tensorflow.model_utils import residual_block
55

6-
def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
6+
7+
def train_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):
78

89
inputs = layers.Input(shape=input_dim, name="input")
910

@@ -29,7 +30,7 @@ def train_model(input_dim, output_dim, activation='leaky_relu', dropout=0.2):
2930
blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
3031
blstm = layers.Dropout(dropout)(blstm)
3132

32-
output = layers.Dense(output_dim + 1, activation='softmax', name="output")(blstm)
33+
output = layers.Dense(output_dim + 1, activation="softmax", name="output")(blstm)
3334

3435
model = Model(inputs=inputs, outputs=output)
35-
return model
36+
return model

0 commit comments

Comments
 (0)