-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprepare_train_val.py
More file actions
54 lines (47 loc) · 2.13 KB
/
prepare_train_val.py
File metadata and controls
54 lines (47 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from scipy import misc
import numpy as np
# Index of the source Cityscapes frames. Each line of dict_file.txt is split on
# "," into [id, subdir, frame_prefix]: column 0 is parsed as an int, and the
# main loop uses columns 1 and 2 to build the gtFine / leftImg8bit file paths.
# NOTE(review): the main loop looks entries up positionally (files[a]), not by
# the parsed id, so dict_file.txt is assumed to list ids 0..N-1 in order -- confirm.
files = []
with open("dict_file.txt", "r") as f:
    for line in f:
        # rstrip rather than line[:-1]: the last line may lack a trailing
        # newline, and slicing would then chop off a real character.
        fields = line.rstrip("\r\n").split(",")
        fields[0] = int(fields[0])
        files.append(fields)
# RGBA colour of road pixels in the colour-coded ground-truth masks
# (cityscapes/gtFine/..._gtFine_color.png); match_grid compares only the
# first three entries, so the alpha value is never checked.
road_color = [128, 64, 128, 255]


def match_grid(mask, c):
    """Return a float32 map that is 1.0 wherever ``mask``'s RGB equals ``c``.

    Args:
        mask: image array indexed as ``mask[row, col, channel]``. Only
            channels 0-2 are examined, so an extra alpha channel is ignored.
        c: colour to look for; only ``c[0]``, ``c[1]`` and ``c[2]`` are used.

    Returns:
        ``np.float32`` array of shape ``mask.shape[:2]``, holding 1.0 at
        pixels whose first three channels all match and 0.0 elsewhere.
    """
    per_channel = [mask[:, :, ch] == c[ch] for ch in range(3)]
    return np.logical_and.reduce(per_channel).astype(np.float32)
def _save_sample(split, name, binary_mask, cutout_im, paste_mask,
                 images_f, labels_f):
    """Write one sample's derived images for ``split`` and record their paths.

    Saves ``binary_mask``, ``cutout_im`` and ``paste_mask`` as
    ``syn_data/{binary_labels,cutouts,paste_mask}/<split>/<name>``. Appends
    the label path to ``labels_f`` and the cutout path to ``images_f``, one
    path per line.
    """
    label_path = "syn_data/binary_labels/" + split + "/" + name
    cutout_path = "syn_data/cutouts/" + split + "/" + name
    misc.imsave(label_path, binary_mask)
    misc.imsave(cutout_path, cutout_im)
    labels_f.write(label_path + "\n")
    images_f.write(cutout_path + "\n")
    misc.imsave("syn_data/paste_mask/" + split + "/" + name, paste_mask)


# Main pass: for every synthetic image under syn_data/images, rebuild its
# targets from the matching Cityscapes frame and write the train/val lists.
# The list files are opened with ``with`` so they are flushed and closed even
# if an image fails to load part-way through.
# NOTE(review): scipy.misc.imread/imsave were deprecated in SciPy 1.0 and later
# removed; this script needs an old SciPy or a port to another image library.
c = 0
with open("train_images.txt", "w") as train_images_f, \
        open("train_labels.txt", "w") as train_labels_f, \
        open("val_images.txt", "w") as val_images_f, \
        open("val_labels.txt", "w") as val_labels_f:
    for root, dirs, paths in os.walk("syn_data/images"):
        # Sort (dirs in place, so os.walk honours it) to get a reproducible
        # processing order and reproducible list files.
        dirs.sort()
        for p in sorted(paths):
            print(c)  # progress: number of images processed so far
            c += 1
            full_path = root + "/" + p
            im = misc.imread(full_path)
            # The file name's leading "<n>_" field is the position of the
            # source frame in ``files``.
            frame_idx = int(p.split("_")[0])
            cityscapes_path = files[frame_idx]
            cityscapes_mask = misc.imread(
                "cityscapes/gtFine/" + cityscapes_path[1] + "/"
                + cityscapes_path[2] + "_gtFine_color.png")
            cityscapes_im = misc.imread(
                "cityscapes/leftImg8bit/" + cityscapes_path[1] + "/"
                + cityscapes_path[2] + "_leftImg8bit.png")
            # 1.0 where the synthetic pixel equals the original Cityscapes
            # pixel in channels 0-2, 0.0 elsewhere.
            eq = np.array(
                np.all(cityscapes_im[:, :, :3] == im[:, :, :3], axis=2),
                dtype=np.float32)
            binary_mask = match_grid(cityscapes_mask, road_color)
            # NOTE(review): presumably forces the image max to 255 so that
            # misc.imsave's min/max byte-scaling keeps road=1 / background=0
            # instead of stretching them to 0/255 -- confirm before changing.
            binary_mask[0, 0] = 255
            # Keep only pixels identical to the original frame. Broadcasting
            # eq over the channel axis works for any image resolution.
            cutout_im = im * eq[:, :, np.newaxis]
            # NOTE(review): substring test on the whole path, so any path that
            # contains "train" (directory or file name) goes to the train split.
            if "train" in full_path:
                _save_sample("train", p, binary_mask, cutout_im, eq,
                             train_images_f, train_labels_f)
            else:
                _save_sample("val", p, binary_mask, cutout_im, eq,
                             val_images_f, val_labels_f)