Skip to content

Commit b97d8b8

Browse files
committed
feat: Add validation (#86)
1 parent 46a2d56 commit b97d8b8

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

validation.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import argparse
2+
import collections
3+
import json
4+
import pickle
5+
import os
6+
import random
7+
import torch
8+
import numpy as np
9+
import pandas as pd
10+
import torch.nn.functional as F
11+
import torch.utils.data as module_data
12+
import data_loader as module_dataset
13+
import model as module_arch
14+
import albumentations as A
15+
from utils import IND2CLASS, encode_mask_to_rle, CLASSES
16+
from parse_config import ConfigParser
17+
from tqdm import tqdm
18+
from sklearn.model_selection import GroupKFold
19+
20+
21+
def dice_coef(outputs, masks, eps=0.0001):
    """Compute the per-sample, per-class Dice coefficient.

    Args:
        outputs: predicted masks, shape (N, C, H, W) — any tensor with >= 3
            dims works, since spatial dims are flattened from dim 2 on.
        masks: ground-truth masks, same shape as ``outputs``.
        eps: smoothing term so empty prediction + empty target scores ~1
            instead of dividing by zero. Default matches the original
            hard-coded 0.0001, so existing callers are unaffected.

    Returns:
        Tensor of shape (N, C) with the Dice score for each sample/class.
    """
    y_true_f = masks.flatten(2)
    y_pred_f = outputs.flatten(2)
    intersection = torch.sum(y_true_f * y_pred_f, -1)
    return (2.0 * intersection + eps) / (
        torch.sum(y_true_f, -1) + torch.sum(y_pred_f, -1) + eps
    )
29+
30+
31+
def set_seeds(seed=42):
    """Seed every RNG the pipeline touches so runs are reproducible."""
    # Python-level generators first.
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Torch CPU generator plus the current and all visible CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
40+
41+
42+
def main(config):
    """Validate a trained segmentation model on one GroupKFold fold.

    For each threshold in ``config["thresholds"]``, runs the model over the
    validation split, prints per-class and average Dice, and writes a CSV of
    RLE-encoded predictions to ``config["path"]["save_csv_path"]``.

    Args:
        config: project ConfigParser instance; provides ``init_obj`` factories
            and dict-style access to paths, k-fold settings and thresholds.
    """
    set_seeds()
    # e.g. ".../saved/<model_name>/checkpoint.pth" -> "<model_name>"
    # (renamed from the original typo `molel_name`; local only, no callers)
    model_name = config["path"]["model_path"].split('/')[-2]
    save_csv_path = config["path"]["save_csv_path"]
    thresholds = config["thresholds"]
    cfg_path = config["path"]

    # NOTE(review): pickle.load on config-supplied paths — assumed trusted,
    # never feed user-controlled files here.
    with open(cfg_path["image_name_pickle_path"], "rb") as f:
        filenames = np.array(pickle.load(f))
    with open(cfg_path["label_name_pickle_path"], "rb") as f:
        labelnames = np.array(pickle.load(f))
    with open(cfg_path["image_dict_pickle_path"], "rb") as f:
        hash_dict = pickle.load(f)

    # Build the albumentations validation pipeline from config.
    valid_tf_list = [
        getattr(A, tf["name"])(*tf["args"], **tf["kwargs"])
        for tf in config["valid_transforms"]
    ]

    # Group k-fold keyed on the parent directory so images of the same
    # subject never straddle the train/valid boundary.
    groups = [os.path.dirname(fname) for fname in filenames]
    ys = [0 for _ in filenames]
    gkf = GroupKFold(n_splits=config["kfold"]["n_splits"])
    for fold, (x, y) in enumerate(gkf.split(filenames, ys, groups), start=1):
        if fold != config["kfold"]["fold"]:
            continue
        valid_filenames = list(filenames[y])
        valid_labelnames = list(labelnames[y])

    valid_dataset = config.init_obj(
        "valid_dataset",
        module_dataset,
        filenames=valid_filenames,
        labelnames=valid_labelnames,
        hash_dict=hash_dict,
        mmap_path=cfg_path["mmap_path"],
        label_root=cfg_path["label_path"],
        transforms=valid_tf_list,
    )
    valid_data_loader = config.init_obj(
        "valid_data_loader", module_data, valid_dataset
    )

    # Build model architecture and load the checkpoint weights.
    model = config.init_obj("arch", module_arch)
    if config["n_gpu"] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(
        torch.load(config["path"]["model_path"])["state_dict"]
    )

    # Prepare model for evaluation.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        for threshold in thresholds:
            dices = []
            rles = []
            filename_and_class = []
            for step, (images, masks, image_names) in tqdm(
                enumerate(valid_data_loader), total=len(valid_data_loader)
            ):
                # BUGFIX: was images.cuda()/masks.cuda(), which crashes on a
                # CPU-only host even though `device` above falls back to CPU.
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)

                # Restore full label resolution before thresholding.
                outputs = F.interpolate(
                    outputs, size=(2048, 2048), mode="bilinear"
                )
                outputs = torch.sigmoid(outputs)
                outputs = (outputs > threshold).detach().cpu()
                masks = masks.detach().cpu()

                dices.append(dice_coef(outputs, masks))

                for output, image_name in zip(outputs, image_names):
                    for c, segm in enumerate(output):
                        rles.append(encode_mask_to_rle(segm))
                        # '_' in the image name is swapped for '-' so the
                        # later split("_") cleanly separates class from name.
                        filename_and_class.append(
                            f"{IND2CLASS[c]}_{image_name.replace('_','-')}"
                        )

            dices = torch.cat(dices, 0)
            dices_per_class = torch.mean(dices, 0)
            dice_str = [
                f"{d.item():.4f}"
                for c, d in zip(CLASSES, dices_per_class)
            ]
            dice_str = "\n".join(dice_str)
            avg_dice = torch.mean(dices_per_class).item()
            print(dice_str)
            print(f'{avg_dice:.4f}')

            classes, filename = zip(*[x.split("_") for x in filename_and_class])
            image_name = [os.path.basename(f) for f in filename]
            df = pd.DataFrame(
                {
                    "image_name": image_name,
                    "class": classes,
                    "rle": rles,
                }
            )
            df.to_csv(
                f'{save_csv_path}/{model_name}_{threshold}.csv', index=False
            )
142+
143+
144+
if __name__ == "__main__":
    # CLI entry point: build the argument parser, hand it to the project's
    # ConfigParser, and run validation.
    cli = argparse.ArgumentParser(description="PyTorch Template")
    cli.add_argument(
        "-c", "--config",
        type=str,
        default="/data/ephemeral/home/level2-cv-semanticsegmentation-cv-03/config_inference.json",
        help="config file path (default: None)",
    )
    cli.add_argument(
        "-r", "--resume",
        type=str,
        default=None,
        help="path to latest checkpoint (default: None)",
    )
    cli.add_argument(
        "-d", "--device",
        type=str,
        default=None,
        help="indices of GPUs to enable (default: all)",
    )
    # Kept for parity with the project's other entry points; not used here.
    CustomArgs = collections.namedtuple("CustomArgs", "flags type target")
    config = ConfigParser.from_args(cli, mode="inference")
    main(config)

0 commit comments

Comments
 (0)