@@ -15,7 +15,6 @@ def unet_train():
15
15
print (n )
16
16
X_train , y_train = [], []
17
17
for i in range (n ):
18
- print ("正在读取第%d张图片" % i )
19
18
img = cv2 .imread (path + 'train_image/%d.png' % i )
20
19
label = cv2 .imread (path + 'train_label/%d.png' % i )
21
20
X_train .append (img )
@@ -91,24 +90,21 @@ def Conv2dT_BN(x, filters, kernel_size, strides=(2, 2), padding='same'):
91
90
model .summary ()
92
91
93
92
print ("开始训练u-net" )
94
- model .fit (X_train , y_train , epochs = 100 , batch_size = 15 )#epochs和batch_size看个人情况调整,batch_size不要过大,否则内存容易溢出
95
- #我11G显存也只能设置15-20左右,我训练最终loss降低至250左右,acc约95%左右
93
+ model .fit (X_train , y_train , epochs = 100 , batch_size = 15 )
96
94
model .save ('unet.h5' )
97
95
print ('unet.h5保存成功!!!' )
98
96
99
97
100
98
def unet_predict (unet , img_src_path ):
101
- img_src = cv2 .imdecode (np .fromfile (img_src_path , dtype = np .uint8 ), - 1 ) # 从中文路径读取时用
102
- # img_src=cv2.imread(img_src_path)
99
+ img_src = cv2 .imdecode (np .fromfile (img_src_path , dtype = np .uint8 ), - 1 )
103
100
if img_src .shape != (512 , 512 , 3 ):
104
- img_src = cv2 .resize (img_src , dsize = (512 , 512 ), interpolation = cv2 .INTER_AREA )[:, :, :3 ] # dsize=(宽度,高度),[:,:,:3]是防止图片为4通道图片,后续无法reshape
105
- img_src = img_src .reshape (1 , 512 , 512 , 3 ) # 预测图片shape为(1,512,512,3)
106
-
107
- img_mask = unet .predict (img_src ) # 归一化除以255后进行预测
108
- img_src = img_src .reshape (512 , 512 , 3 ) # 将原图reshape为3维
109
- img_mask = img_mask .reshape (512 , 512 , 3 ) # 将预测后图片reshape为3维
110
- img_mask = img_mask / np .max (img_mask ) * 255 # 归一化后乘以255
111
- img_mask [:, :, 2 ] = img_mask [:, :, 1 ] = img_mask [:, :, 0 ] # 三个通道保持相同
112
- img_mask = img_mask .astype (np .uint8 ) # 将img_mask类型转为int型
101
+ img_src = cv2 .resize (img_src , dsize = (512 , 512 ), interpolation = cv2 .INTER_AREA )[:, :, :3 ]
102
+ img_src = img_src .reshape (1 , 512 , 512 , 3 )
103
+ img_mask = unet .predict (img_src )
104
+ img_src = img_src .reshape (512 , 512 , 3 )
105
+ img_mask = img_mask .reshape (512 , 512 , 3 )
106
+ img_mask = img_mask / np .max (img_mask ) * 255
107
+ img_mask [:, :, 2 ] = img_mask [:, :, 1 ] = img_mask [:, :, 0 ]
108
+ img_mask = img_mask .astype (np .uint8 )
113
109
114
110
return img_src , img_mask
0 commit comments