-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
executable file
·35 lines (28 loc) · 1.02 KB
/
predict.py
File metadata and controls
executable file
·35 lines (28 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import tensorflow as tf
from PIL import Image
import numpy as np
from train import CNN
class Predict(object):
    """Restore a trained CNN from its newest checkpoint and classify images.

    The model itself is built by ``CNN`` (imported from ``train``); this class
    only loads the saved weights and runs single-image inference.
    """

    # (height, width) in pixels of the grayscale image fed to the model.
    _IMAGE_SHAPE = (28, 28)

    def __init__(self, ckpt_dir='./ckpt'):
        """Build the network and load the latest checkpoint's weights.

        Args:
            ckpt_dir: Directory that is searched for checkpoints.

        Raises:
            FileNotFoundError: If no checkpoint is found in ``ckpt_dir``.
        """
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest is None:
            # latest_checkpoint() signals "nothing found" with None; fail
            # here with a clear message instead of inside load_weights().
            raise FileNotFoundError(
                'No checkpoint found in {!r}'.format(ckpt_dir))
        self.cnn = CNN()
        # Restore the network weights.
        self.cnn.model.load_weights(latest)

    @staticmethod
    def _to_batch(pixels):
        """Turn a 2-D grayscale pixel array into a one-image model batch.

        Args:
            pixels: Array of shape ``(28, 28)`` holding 0-255 intensities.

        Returns:
            Float array of shape ``(1, 28, 28, 1)`` with values in ``[0, 1]``,
            inverted so that an input of 255 becomes 0 and 0 becomes 1.

        Raises:
            ValueError: If ``pixels`` does not have shape ``(28, 28)``.
        """
        pixels = np.asarray(pixels)
        if pixels.shape != Predict._IMAGE_SHAPE:
            # A bare reshape would silently accept any 784-element array
            # (e.g. 14x56) and scramble the picture, so check explicitly.
            raise ValueError(
                'Expected an image of shape {}, got {}'.format(
                    Predict._IMAGE_SHAPE, pixels.shape))
        scaled = pixels.reshape(Predict._IMAGE_SHAPE + (1,)) / 255.
        # NOTE(review): the inversion presumably matches the polarity of the
        # training data -- confirm against train.py.
        return np.array([1 - scaled])

    def predict(self, image_path):
        """Classify the image at ``image_path`` and print the result.

        Prints the path, the model output for the image, and the predicted
        digit.

        Args:
            image_path: Path of the image file to classify.

        Returns:
            The index of the largest model output, i.e. the predicted digit.

        Raises:
            ValueError: If the image is not 28x28 pixels.
        """
        # Read the picture as grayscale; the context manager makes sure the
        # underlying file handle is closed.
        with Image.open(image_path) as image:
            pixels = np.asarray(image.convert('L'))
        x = self._to_batch(pixels)
        # API refer: https://keras.io/models/model/
        y = self.cnn.model.predict(x)
        # Only one image was passed in, so y[0] is its output row;
        # np.argmax() gives the index of the maximum, i.e. the digit.
        digit = int(np.argmax(y[0]))
        print(image_path)
        print(y[0])
        print(' -> Predict digit', digit)
        return digit
if __name__ == "__main__":
    # Sample test images to classify, in the order they are reported.
    sample_paths = (
        'mnist/test_images/0_0.png',
        'mnist/test_images/1_63.png',
        'mnist/test_images/3_195.png',
        'mnist/test_images/4_66.png',
    )
    app = Predict()
    for sample_path in sample_paths:
        app.predict(sample_path)