Commit f33fdb8

Merge branch 'feature/pytorch-handwritting-recognition' into develop
2 parents a925d48 + 2298650 commit f33fdb8

22 files changed: +758 -37 lines changed

CHANGELOG.md

Lines changed: 16 additions & 3 deletions
@@ -1,3 +1,16 @@
+## [1.0.2] - 2022-03-... (unreleased)
+### Changed
+- changed `OnnxInferenceModel` in `mltu.torch.inferenceModels` to load custom metadata from saved ONNX model
+- improved `mltu.dataProvider` to remove bad samples from dataset on epoch end
+
+### Added:
+- added `mltu.torch.losses`, used to create PyTorch losses, that may be used in training and validation
+- added CTC loss to `mltu.torch.losses` that can be used for training CTC based models
+- added `Model2onnx` and `TensorBoard` callbacks to `mltu.torch.callbacks`, used to create PyTorch callbacks, that may be used in training and validation
+- added `CERMetric` and `WERMetric` to `mltu.torch.metrics`, used to create PyTorch metrics, that may be used in training and validation
+- created 08 pytorch tutorial, that shows how to use `mltu.torch` to train CTC based models
+
+
 ## [1.0.1] - 2022-03-06
 ### Changed
 - In all tutorials removed stow dependency and replaced with os package, to make it easier to use on Windows 11
@@ -25,7 +38,7 @@
 -
 ### Added:
 - added 05_sound_to_text tutorial
-- added WavReader to mltu/preprocessors, used to read wav files and convert them to numpy arrays
+- added `WavReader` to `mltu/preprocessors`, used to read wav files and convert them to numpy arrays
 
 
 ## [0.1.7] - 2022-02-03
@@ -35,11 +48,11 @@
 
 ## [0.1.5] - 2022-01-10
 ### Changed
-- separated CWERMetric to SER and WER Metrics in mltu.metrics, Character/word rate was calculated in a wrong way
+- separated `CWERMetric` to `CER` and `WER` Metrics in `mltu.metrics`, Character/word rate was calculated in a wrong way
 - created @setter for augmentors and transformers in DataProvider, to properly add augmentors and transformers to the pipeline
 - augmentors and transformers must inherit from `mltu.augmentors.base.Augmentor` and `mltu.transformers.base.Transformer` respectively
 - updated ImageShowCV2 transformer documentation
-- fixed OnnxInferenceModel in mltu.inferenceModels to use CPU even if GPU is available with force_cpu=True flag
+- fixed OnnxInferenceModel in `mltu.inferenceModels` to use CPU even if GPU is available with force_cpu=True flag
 
 ### Added:
 - added RandomSharpen to mltu.augmentors, used for simple image augmentation;
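
Note on the `CERMetric` and `WERMetric` changelog entries: the metric code itself is not part of the hunks shown here, so the following is only a rough sketch of what character and word error rates usually mean, not mltu's implementation. The helper names `edit_distance`, `cer` and `wer` are hypothetical. Both rates are taken here as a Levenshtein distance divided by the length of the reference.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or lists of words)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # delete r
                curr[j - 1] + 1,         # insert h
                prev[j - 1] + (r != h),  # substitute (free when equal)
            ))
        prev = curr
    return prev[-1]


def cer(prediction, target):
    """Character error rate: character edits divided by target length."""
    if not target:
        return 0.0 if not prediction else 1.0
    return edit_distance(target, prediction) / len(target)


def wer(prediction, target):
    """Word error rate: word edits divided by the number of target words."""
    target_words = target.split()
    if not target_words:
        return 0.0 if not prediction.split() else 1.0
    return edit_distance(target_words, prediction.split()) / len(target_words)
```

With these definitions a single missing character in a five-character word gives a CER of 0.2, while the same mistake in a one-word label gives a WER of 1.0, which is why the two rates are tracked separately.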

README.md

Lines changed: 2 additions & 1 deletion
@@ -22,4 +22,5 @@ Each tutorial has its own requirements.txt file for a specific mltu version. As
 4. [Handwritten sentence recognition with TensorFlow](https://pylessons.com/handwritten-sentence-recognition), code in ```Tutorials\04_sentence_recognition``` folder;
 5. [Introduction to speech recognition with TensorFlow](https://pylessons.com/speech-recognition), code in ```Tutorials\05_speech_recognition``` folder;
 6. [Introduction to PyTorch in a practical way](https://pylessons.com/pytorch-introduction), code in ```Tutorials\06_pytorch_introduction``` folder;
-7. [Using custom wrapper to simplify PyTorch models training pipeline](https://pylessons.com/pytorch-introduction), code in ```Tutorials\07_pytorch_wrapper``` folder;
+7. [Using custom wrapper to simplify PyTorch models training pipeline](https://pylessons.com/pytorch-introduction), code in ```Tutorials\07_pytorch_wrapper``` folder;
+8. [Handwriting words recognition with PyTorch](https://pylessons.com/handwriting-recognition-pytorch), code in ```Tutorials\08_handwriting_recognition_torch``` folder;

Tutorials/02_captcha_to_text/train.py

Lines changed: 4 additions & 4 deletions
@@ -34,10 +34,10 @@ def download_and_unzip(url, extract_to='Datasets'):
 captcha_path = os.path.join('Datasets', 'captcha_images_v2')
 for file in os.listdir(captcha_path):
     file_path = os.path.join(captcha_path, file)
-    file_name = os.path.splitext(file)[0]
-    dataset.append([file_path, file_name])
-    vocab.update(list(file_name))
-    max_len = max(max_len, len(file_name))
+    label = os.path.splitext(file)[0] # Get the file name without the extension
+    dataset.append([file_path, label])
+    vocab.update(list(label))
+    max_len = max(max_len, len(label))
 
 configs = ModelConfigs()
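
Note on the rename above: the loop treats each captcha file name (minus extension) as the label, and accumulates the vocabulary and the longest label length from it. A minimal standalone sketch of that step, using an in-memory list of names instead of `os.listdir` so it runs without the dataset; `collect_labels` is a hypothetical helper, not a function in the repository.

```python
import os


def collect_labels(file_names):
    """Derive (path, label) pairs, a sorted vocabulary and the longest label."""
    dataset, vocab, max_len = [], set(), 0
    for file in file_names:
        label = os.path.splitext(file)[0]  # file name without the extension
        dataset.append([file, label])
        vocab.update(label)                # adds each character of the label
        max_len = max(max_len, len(label))
    return dataset, "".join(sorted(vocab)), max_len


dataset, vocab, max_len = collect_labels(["2b827.png", "xy3c.jpeg"])
print(vocab, max_len)
```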

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Using custom wrapper to simplify PyTorch models training pipeline
+### Construct an accurate handwriting recognition model with PyTorch! Understand how to use MLTU package, to simplify the PyTorch models training pipeline, and discover methods to enhance your model's accuracy!<br><br>
+
+# **Detailed tutorial**:
+### [Handwriting words recognition with PyTorch](https://pylessons.com/handwriting-recognition-pytorch)
+
+<p align="center">
+  <img src="https://pylessons.com/media/Tutorials/mltu/handwriting-recognition-pytorch/handwriting-recognition-pytorch.png">
+</p>
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import os
+from datetime import datetime
+
+from mltu.configs import BaseModelConfigs
+
+class ModelConfigs(BaseModelConfigs):
+    def __init__(self):
+        super().__init__()
+        self.model_path = os.path.join('Models/08_handwriting_recognition_torch', datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
+        self.vocab = ''
+        self.height = 32
+        self.width = 128
+        self.max_text_length = 0
+        self.batch_size = 64
+        self.learning_rate = 0.002
+        self.train_epochs = 1000
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import cv2
+import typing
+import numpy as np
+
+from mltu.inferenceModel import OnnxInferenceModel
+from mltu.utils.text_utils import ctc_decoder, get_cer
+
+class ImageToWordModel(OnnxInferenceModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def predict(self, image: np.ndarray):
+        image = cv2.resize(image, self.input_shape[:2][::-1])
+
+        image_pred = np.expand_dims(image, axis=0).astype(np.float32)
+
+        preds = self.model.run(None, {self.input_name: image_pred})[0]
+
+        text = ctc_decoder(preds, self.vocab)[0]
+
+        return text
+
+if __name__ == "__main__":
+    import pandas as pd
+    from tqdm import tqdm
+
+    model = ImageToWordModel(model_path="Models/08_handwriting_recognition_torch/202303142139/model.onnx")
+
+    df = pd.read_csv("Models/08_handwriting_recognition_torch/202303142139/val.csv").values.tolist()
+
+    accum_cer = []
+    for image_path, label in tqdm(df):
+        image = cv2.imread(image_path)
+
+        prediction_text = model.predict(image)
+
+        cer = get_cer(prediction_text, label)
+        print(f"Image: {image_path}, Label: {label}, Prediction: {prediction_text}, CER: {cer}")
+
+        accum_cer.append(cer)
+
+    print(f"Average CER: {np.average(accum_cer)}")
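
Note on `ctc_decoder`: it is imported from `mltu.utils.text_utils` and its body is not part of this commit, so the sketch below is an assumption about the general technique (greedy CTC decoding), not the library's code. It assumes the blank index is `len(vocab)`, which matches the `CTCLoss(blank=len(configs.vocab))` call in the training script of this commit. `greedy_ctc_decode` and `one_hot` are hypothetical names, and plain Python lists stand in for the model output.

```python
def greedy_ctc_decode(frame_scores, vocab):
    """Greedy CTC decoding for one sample.

    frame_scores: list of per-time-step score lists, each of length len(vocab) + 1.
    The last index (len(vocab)) is assumed to be the CTC blank.
    """
    blank = len(vocab)
    # Pick the best class at every time step.
    best = [max(range(len(scores)), key=scores.__getitem__) for scores in frame_scores]

    chars = []
    prev = None
    for idx in best:
        # Collapse repeated classes, then drop blanks.
        if idx != prev and idx != blank:
            chars.append(vocab[idx])
        prev = idx
    return "".join(chars)


def one_hot(index, size):
    """Toy score vector with a single winning class."""
    return [1.0 if i == index else 0.0 for i in range(size)]


vocab = "abc"
path = [0, 0, 3, 0, 1, 3, 3, 2]  # 3 is the blank for a 3-character vocab
frames = [one_hot(i, len(vocab) + 1) for i in path]
decoded = greedy_ctc_decode(frames, vocab)
print(decoded)
```

The blank between the two runs of index 0 is what lets the decoder emit a doubled letter; without it the repeated frames would collapse into one character.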
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def activation_layer(activation: str='relu', alpha: float=0.1, inplace: bool=True):
+    """ Activation layer wrapper for LeakyReLU and ReLU activation functions
+
+    Args:
+        activation: str, activation function name (default: 'relu')
+        alpha: float (LeakyReLU activation function parameter)
+
+    Returns:
+        nn.Module: activation layer
+    """
+    if activation == 'relu':
+        return nn.ReLU(inplace=inplace)
+
+    elif activation == 'leaky_relu':
+        return nn.LeakyReLU(negative_slope=alpha, inplace=inplace)
+
+class ConvBlock(nn.Module):
+    """ Convolutional block with batch normalization
+    """
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor):
+        return self.bn(self.conv(x))
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, skip_conv=True, stride=1, dropout=0.2, activation='leaky_relu'):
+        super(ResidualBlock, self).__init__()
+        self.convb1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
+        self.act1 = activation_layer(activation)
+
+        self.convb2 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        self.shortcut = None
+        if skip_conv:
+            if stride != 1 or in_channels != out_channels:
+                self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
+
+        self.act2 = activation_layer(activation)
+
+    def forward(self, x):
+        skip = x
+
+        out = self.act1(self.convb1(x))
+        out = self.convb2(out)
+
+        if self.shortcut is not None:
+            out += self.shortcut(skip)
+
+        out = self.act2(out)
+        out = self.dropout(out)
+
+        return out
+
+class Network(nn.Module):
+    """ Handwriting recognition network for CTC loss"""
+    def __init__(self, num_chars: int, activation: str='leaky_relu', dropout: float=0.2):
+        super(Network, self).__init__()
+
+        self.rb1 = ResidualBlock(3, 16, skip_conv = True, stride=1, activation=activation, dropout=dropout)
+        self.rb2 = ResidualBlock(16, 16, skip_conv = True, stride=2, activation=activation, dropout=dropout)
+        self.rb3 = ResidualBlock(16, 16, skip_conv = False, stride=1, activation=activation, dropout=dropout)
+
+        self.rb4 = ResidualBlock(16, 32, skip_conv = True, stride=2, activation=activation, dropout=dropout)
+        self.rb5 = ResidualBlock(32, 32, skip_conv = False, stride=1, activation=activation, dropout=dropout)
+
+        self.rb6 = ResidualBlock(32, 64, skip_conv = True, stride=2, activation=activation, dropout=dropout)
+        self.rb7 = ResidualBlock(64, 64, skip_conv = True, stride=1, activation=activation, dropout=dropout)
+
+        self.rb8 = ResidualBlock(64, 64, skip_conv = False, stride=1, activation=activation, dropout=dropout)
+        self.rb9 = ResidualBlock(64, 64, skip_conv = False, stride=1, activation=activation, dropout=dropout)
+
+        self.lstm = nn.LSTM(64, 128, bidirectional=True, num_layers=1, batch_first=True)
+        self.lstm_dropout = nn.Dropout(p=dropout)
+
+        self.output = nn.Linear(256, num_chars + 1)
+
+    def forward(self, images: torch.Tensor) -> torch.Tensor:
+        # normalize images between 0 and 1
+        images_float = images / 255.0
+
+        # transpose image to channel first
+        images_float = images_float.permute(0, 3, 1, 2)
+
+        # apply convolutions
+        x = self.rb1(images_float)
+        x = self.rb2(x)
+        x = self.rb3(x)
+        x = self.rb4(x)
+        x = self.rb5(x)
+        x = self.rb6(x)
+        x = self.rb7(x)
+        x = self.rb8(x)
+        x = self.rb9(x)
+
+        x = x.reshape(x.size(0), -1, x.size(1))
+
+        x, _ = self.lstm(x)
+        x = self.lstm_dropout(x)
+
+        x = self.output(x)
+        x = F.log_softmax(x, 2)
+
+        return x
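
Note on the sequence length seen by the LSTM: it follows from the three stride-2 residual blocks. The sketch below redoes that arithmetic in plain Python using the standard convolution output-size formula; it assumes the 32x128 input from `configs.py` and the 3x3, padding-1 convolutions shown above. `conv_out` and `feature_map_size` are hypothetical helper names.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_size(height, width, strides=(1, 2, 1, 2, 1, 2, 1, 1, 1)):
    """Spatial size after rb1..rb9; only the first conv of each block is strided."""
    for stride in strides:
        height = conv_out(height, stride=stride)
        width = conv_out(width, stride=stride)
    return height, width


h, w = feature_map_size(32, 128)
time_steps = h * w  # reshape(batch, -1, channels) folds height and width together
print(h, w, time_steps)  # 4 16 64
```

Under these assumptions the network emits 64 time steps per image. CTC needs at least as many time steps as label characters, so this is the upper bound on the label length the model can be trained on.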
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+torch==1.13.1
+tensorboard==2.10.1
+onnx==1.12.0
+torchsummaryX
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+import os
+import tarfile
+from tqdm import tqdm
+from io import BytesIO
+from zipfile import ZipFile
+from urllib.request import urlopen
+
+import torch
+import torch.optim as optim
+from torchsummaryX import summary
+
+from mltu.torch.model import Model
+from mltu.torch.losses import CTCLoss
+from mltu.torch.dataProvider import DataProvider
+from mltu.torch.metrics import CERMetric, WERMetric
+from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, ReduceLROnPlateau
+
+from mltu.preprocessors import ImageReader
+from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
+from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
+
+from model import Network
+from configs import ModelConfigs
+
+def download_and_unzip(url, extract_to='Datasets', chunk_size=1024*1024):
+    http_response = urlopen(url)
+
+    data = b''
+    iterations = http_response.length // chunk_size + 1
+    for _ in tqdm(range(iterations)):
+        data += http_response.read(chunk_size)
+
+    zipfile = ZipFile(BytesIO(data))
+    zipfile.extractall(path=extract_to)
+
+dataset_path = os.path.join('Datasets', 'IAM_Words')
+if not os.path.exists(dataset_path):
+    download_and_unzip('https://git.io/J0fjL', extract_to='Datasets')
+
+    file = tarfile.open(os.path.join(dataset_path, "words.tgz"))
+    file.extractall(os.path.join(dataset_path, "words"))
+
+dataset, vocab, max_len = [], set(), 0
+
+# Preprocess the dataset by the specific IAM_Words dataset file structure
+words = open(os.path.join(dataset_path, "words.txt"), "r").readlines()
+for line in tqdm(words):
+    if line.startswith("#"):
+        continue
+
+    line_split = line.split(" ")
+    if line_split[1] == "err":
+        continue
+
+    folder1 = line_split[0][:3]
+    folder2 = "-".join(line_split[0].split("-")[:2])
+    file_name = line_split[0] + ".png"
+    label = line_split[-1].rstrip('\n')
+
+    rel_path = os.path.join(dataset_path, "words", folder1, folder2, file_name)
+    if not os.path.exists(rel_path):
+        print(f"File not found: {rel_path}")
+        continue
+
+    dataset.append([rel_path, label])
+    vocab.update(list(label))
+    max_len = max(max_len, len(label))
+
+configs = ModelConfigs()
+
+# Save vocab and maximum text length to configs
+configs.vocab = "".join(sorted(vocab))
+configs.max_text_length = max_len
+configs.save()
+
+# Create a data provider for the dataset
+data_provider = DataProvider(
+    dataset=dataset,
+    skip_validation=True,
+    batch_size=configs.batch_size,
+    data_preprocessors=[ImageReader()],
+    transformers=[
+        # ImageShowCV2(), # uncomment to show images during training
+        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
+        LabelIndexer(configs.vocab),
+        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab))
+        ],
+    use_cache=True,
+)
+
+# Split the dataset into training and validation sets
+train_dataProvider, test_dataProvider = data_provider.split(split = 0.9)
+
+# Augment training data with random brightness, erode/dilate, sharpen and rotation
+train_dataProvider.augmentors = [
+    RandomBrightness(),
+    RandomErodeDilate(),
+    RandomSharpen(),
+    RandomRotate(angle=10),
+    ]
+
+network = Network(len(configs.vocab), activation='leaky_relu', dropout=0.3)
+loss = CTCLoss(blank=len(configs.vocab))
+optimizer = optim.Adam(network.parameters(), lr=configs.learning_rate)
+
+# print network summary (comment out if the torchsummaryX package is not installed)
+summary(network, torch.zeros((1, configs.height, configs.width, 3)))
+
+# put on cuda device if available
+if torch.cuda.is_available():
+    network = network.cuda()
+
+# create callbacks
+earlyStopping = EarlyStopping(monitor='val_CER', patience=20, mode="min", verbose=1)
+modelCheckpoint = ModelCheckpoint(configs.model_path + '/model.pt', monitor='val_CER', mode="min", save_best_only=True, verbose=1)
+tb_callback = TensorBoard(configs.model_path + '/logs')
+reduce_lr = ReduceLROnPlateau(monitor='val_CER', factor=0.9, patience=10, verbose=1, mode='min', min_lr=1e-6)
+model2onnx = Model2onnx(
+    saved_model_path=configs.model_path + '/model.pt',
+    input_shape=(1, configs.height, configs.width, 3),
+    verbose=1,
+    metadata={"vocab": configs.vocab}
+)
+
+# create model object that will handle training and testing of the network
+model = Model(network, optimizer, loss, metrics=[CERMetric(configs.vocab), WERMetric(configs.vocab)])
+model.fit(
+    train_dataProvider,
+    test_dataProvider,
+    epochs=1000,
+    callbacks=[earlyStopping, modelCheckpoint, tb_callback, reduce_lr, model2onnx]
+    )
+
+# Save training and validation datasets as csv files
+train_dataProvider.to_csv(os.path.join(configs.model_path, 'train.csv'))
+test_dataProvider.to_csv(os.path.join(configs.model_path, 'val.csv'))
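
Note on the label pipeline: `LabelPadding` is given `padding_value=len(configs.vocab)` and `CTCLoss` is given `blank=len(configs.vocab)`, so padded positions share the index of the CTC blank. The transformer classes are not part of this commit, so the sketch below only illustrates the idea in plain Python; `index_label` and `pad_label` are hypothetical stand-ins, not mltu functions.

```python
def index_label(label, vocab):
    """Map each character to its position in the vocabulary string."""
    return [vocab.index(ch) for ch in label if ch in vocab]


def pad_label(indices, max_len, padding_value):
    """Right-pad an index list to a fixed length."""
    return indices + [padding_value] * (max_len - len(indices))


vocab = "".join(sorted(set("catdog")))  # same construction as configs.vocab
blank = len(vocab)                      # index reserved for CTC blank and padding

indices = index_label("cat", vocab)
padded = pad_label(indices, max_len=5, padding_value=blank)
print(vocab, indices, padded)
```

Keeping the padding value outside the range of real character indices means a padded label can be told apart from a real one, and the network's output layer (`num_chars + 1` units) has exactly one extra class for it.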
