-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathtest_model.py
More file actions
62 lines (47 loc) · 1.58 KB
/
test_model.py
File metadata and controls
62 lines (47 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# encoding=utf-8
import os
import re
import time
import math
import tensorflow as tf
import utils
class MyClassback(tf.keras.callbacks.Callback):
    """Keras callback that periodically checkpoints weights and stops early.

    Every ``save_every`` epochs (epoch indices 0, save_every, 2*save_every, ...)
    the model's weights are written to ``save_dir``. Once the logged training
    accuracy reaches ``target_acc``, the weights are saved and training is
    stopped.
    """

    def __init__(self, model_prefix: str, save_dir: str,
                 save_every: int = 5, target_acc: float = 0.99):
        """
        Args:
            model_prefix: Prefix used for every weights filename.
            save_dir: Directory the ``.h5`` weight files are written to;
                created on first save if it does not exist.
            save_every: Save weights whenever ``epoch % save_every == 0``.
                A value of 0 disables periodic saving.
            target_acc: Accuracy at or above which training is stopped.
        """
        # The base Callback.__init__ initialises attributes Keras relies on
        # (such as ``model``), so it must run.
        super().__init__()
        self.prefix = model_prefix
        self.save_dir = save_dir
        self.save_every = save_every
        self.target_acc = target_acc

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint on schedule and stop training once the target is hit.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Metric dict supplied by Keras; may be None.
        """
        logs = logs or {}
        # With ``metrics=['accuracy']``, TF 2.x logs the value under
        # 'accuracy'; older Keras versions used 'acc'.
        acc = logs.get('accuracy', logs.get('acc'))
        reached_target = acc is not None and acc >= self.target_acc
        if reached_target:
            self.model.stop_training = True
        periodic = bool(self.save_every) and epoch % self.save_every == 0
        # Save once, even when both conditions hold for the same epoch.
        if periodic or reached_target:
            self.save_model_weights(epoch, acc)

    def save_model_weights(self, epoch, acc):
        """Write the model's weights to an ``.h5`` file inside ``save_dir``.

        The filename contains the prefix, ``utils.time_now()`` (presumably a
        timestamp string -- confirm in utils), the epoch zero-padded to two
        digits, and the accuracy to two decimals ('na' if none was logged).
        """
        os.makedirs(self.save_dir, exist_ok=True)
        acc_str = 'na' if acc is None else f'{acc:.2f}'
        filename = (
            f'{self.prefix}_{utils.time_now()}'
            f'_epoch_{str(epoch).zfill(2)}_acc_{acc_str}.h5'
        )
        self.model.save_weights(os.path.join(self.save_dir, filename))
if __name__ == "__main__":
    model = utils.kanji_model_v3()
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.summary()

    dataset_filename = 'kanji_dataset.tfrecord'
    # Used both as the number of records (for steps_per_epoch) and as the
    # shuffle buffer size, so the shuffle spans the whole cached dataset.
    buffer_size = 256776
    batch_size = 256
    steps_per_epoch = math.ceil(buffer_size / batch_size)

    # tf.data transformations return a NEW dataset rather than modifying the
    # receiver, so every step must be reassigned to take effect.
    ds = utils.load_tfrecord(dataset_filename)
    ds = ds.cache()
    # Equivalent to the deprecated tf.data.experimental.shuffle_and_repeat.
    ds = ds.shuffle(buffer_size).repeat()
    ds = ds.batch(batch_size)
    print(ds)

    # Training is currently commented out; re-enable to train.
    # callback = MyClassback('kanji_model_v3', 'kanji_model_v3')
    # model.fit(
    #     ds,
    #     steps_per_epoch=steps_per_epoch,
    #     epochs=5,
    #     callbacks=[callback],
    # )