|
4 | 4 |
|
5 | 5 | from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard |
6 | 6 |
|
7 | | -from mltu.dataProvider import DataProvider |
| 7 | +from mltu.tensorflow.dataProvider import DataProvider |
| 8 | +from mltu.tensorflow.losses import CTCloss |
| 9 | +from mltu.tensorflow.callbacks import Model2onnx, TrainLogger |
| 10 | +from mltu.tensorflow.metrics import CWERMetric |
| 11 | + |
8 | 12 | from mltu.preprocessors import ImageReader |
9 | 13 | from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding |
10 | 14 | from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate |
11 | | -from mltu.losses import CTCloss |
12 | | -from mltu.callbacks import Model2onnx, TrainLogger |
13 | | -from mltu.metrics import CWERMetric |
14 | 15 |
|
15 | 16 | from model import train_model |
16 | 17 | from configs import ModelConfigs |
@@ -70,15 +71,15 @@ def download_and_unzip(url, extract_to='Datasets'): |
70 | 71 | model.compile( |
71 | 72 | optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate), |
72 | 73 | loss=CTCloss(), |
73 | | - metrics=[CWERMetric()], |
| 74 | + metrics=[CWERMetric(padding_token=len(configs.vocab))], |
74 | 75 | run_eagerly=False |
75 | 76 | ) |
76 | 77 | model.summary(line_length=110) |
77 | 78 | # Define path to save the model |
78 | 79 | stow.mkdir(configs.model_path) |
79 | 80 |
|
80 | 81 | # Define callbacks |
81 | | -earlystopper = EarlyStopping(monitor='val_CER', patience=40, verbose=1) |
| 82 | +earlystopper = EarlyStopping(monitor='val_CER', patience=50, verbose=1) |
82 | 83 | checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor='val_CER', verbose=1, save_best_only=True, mode='min') |
83 | 84 | trainLogger = TrainLogger(configs.model_path) |
84 | 85 | tb_callback = TensorBoard(f'{configs.model_path}/logs', update_freq=1) |
|
0 commit comments