Skip to content

Commit 4fcce05

Browse files
committed
Fixed minor issue it's ready to use
1 parent 54bfbed commit 4fcce05

File tree

1 file changed

+10
-14
lines changed
  • Chinese_plate_scan/License-plate-recognition

1 file changed

+10
-14
lines changed

Chinese_plate_scan/License-plate-recognition/Unet.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def unet_train():
1515
print(n)
1616
X_train, y_train = [], []
1717
for i in range(n):
18-
print("正在读取第%d张图片" % i)
1918
img = cv2.imread(path + 'train_image/%d.png' % i)
2019
label = cv2.imread(path + 'train_label/%d.png' % i)
2120
X_train.append(img)
@@ -91,24 +90,21 @@ def Conv2dT_BN(x, filters, kernel_size, strides=(2, 2), padding='same'):
9190
model.summary()
9291

9392
print("开始训练u-net")
94-
model.fit(X_train, y_train, epochs=100, batch_size=15)#epochs和batch_size看个人情况调整,batch_size不要过大,否则内存容易溢出
95-
#我11G显存也只能设置15-20左右,我训练最终loss降低至250左右,acc约95%左右
93+
model.fit(X_train, y_train, epochs=100, batch_size=15)
9694
model.save('unet.h5')
9795
print('unet.h5保存成功!!!')
9896

9997

10098
def unet_predict(unet, img_src_path):
101-
img_src = cv2.imdecode(np.fromfile(img_src_path, dtype=np.uint8), -1) # 从中文路径读取时用
102-
# img_src=cv2.imread(img_src_path)
99+
img_src = cv2.imdecode(np.fromfile(img_src_path, dtype=np.uint8), -1)
103100
if img_src.shape != (512, 512, 3):
104-
img_src = cv2.resize(img_src, dsize=(512, 512), interpolation=cv2.INTER_AREA)[:, :, :3] # dsize=(宽度,高度),[:,:,:3]是防止图片为4通道图片,后续无法reshape
105-
img_src = img_src.reshape(1, 512, 512, 3) # 预测图片shape为(1,512,512,3)
106-
107-
img_mask = unet.predict(img_src) # 归一化除以255后进行预测
108-
img_src = img_src.reshape(512, 512, 3) # 将原图reshape为3维
109-
img_mask = img_mask.reshape(512, 512, 3) # 将预测后图片reshape为3维
110-
img_mask = img_mask / np.max(img_mask) * 255 # 归一化后乘以255
111-
img_mask[:, :, 2] = img_mask[:, :, 1] = img_mask[:, :, 0] # 三个通道保持相同
112-
img_mask = img_mask.astype(np.uint8) # 将img_mask类型转为int型
101+
img_src = cv2.resize(img_src, dsize=(512, 512), interpolation=cv2.INTER_AREA)[:, :, :3]
102+
img_src = img_src.reshape(1, 512, 512, 3)
103+
img_mask = unet.predict(img_src)
104+
img_src = img_src.reshape(512, 512, 3)
105+
img_mask = img_mask.reshape(512, 512, 3)
106+
img_mask = img_mask / np.max(img_mask) * 255
107+
img_mask[:, :, 2] = img_mask[:, :, 1] = img_mask[:, :, 0]
108+
img_mask = img_mask.astype(np.uint8)
113109

114110
return img_src, img_mask

0 commit comments

Comments
 (0)