Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions demo_testvoc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: https://kazuto1011.github.io
# Date: 07 January 2019

from __future__ import absolute_import, division, print_function
from configparser import Interpolation

import click
import cv2
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from libs.models import *
from libs.utils import DenseCRF

import os
import pandas as pd
from tqdm import tqdm

def get_device(cuda):
cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
if cuda:
current_device = torch.cuda.current_device()
print("Device:", torch.cuda.get_device_name(current_device))
else:
print("Device: CPU")
return device


def get_classtable(CONFIG):
with open(CONFIG.DATASET.LABELS) as f:
classes = {}
for label in f:
label = label.rstrip().split("\t")
classes[int(label[0])] = label[1].split(",")[0]
return classes


def setup_postprocessor(CONFIG):
# CRF post-processor
postprocessor = DenseCRF(
iter_max=CONFIG.CRF.ITER_MAX,
pos_xy_std=CONFIG.CRF.POS_XY_STD,
pos_w=CONFIG.CRF.POS_W,
bi_xy_std=CONFIG.CRF.BI_XY_STD,
bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
bi_w=CONFIG.CRF.BI_W,
)
return postprocessor


def preprocessing(image, device, CONFIG):
# Resize
scale = CONFIG.IMAGE.SIZE.TEST / max(image.shape[:2])
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
raw_image = image.astype(np.uint8)

# Subtract mean values
image = image.astype(np.float32)
image -= np.array(
[
float(CONFIG.IMAGE.MEAN.B),
float(CONFIG.IMAGE.MEAN.G),
float(CONFIG.IMAGE.MEAN.R),
]
)

# Convert to torch.Tensor and add "batch" axis
image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
image = image.to(device)

return image, raw_image, scale


def inference(model, image, raw_image=None, postprocessor=None):
_, _, H, W = image.shape

# Image -> Probability map
logits = model(image)
logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
probs = F.softmax(logits, dim=1)[0]
probs = probs.cpu().numpy()

# Refine the prob map with CRF
if postprocessor and raw_image is not None:
probs = postprocessor(raw_image, probs)

labelmap = np.argmax(probs, axis=0)

return labelmap


@click.group()
@click.pass_context
def main(ctx):
"""
Demo with a trained model
"""

print("Mode:", ctx.invoked_subcommand)

@main.command()
@click.option(
"-c",
"--config-path",
type=click.File(),
required=True,
help="Dataset configuration file in YAML",
)
@click.option(
"-m",
"--model-path",
type=click.Path(exists=True),
required=True,
help="PyTorch model to be loaded",
)
@click.option(
"-f",
"--txt-file",
type=click.Path(exists=True),
required=True,
help="CSV file to be processed",
)
@click.option(
"-i",
"--img-dir",
type=click.Path(exists=True),
required=True,
help="Image to be processed",
)

@click.option(
"-s",
"--save-dir",
type=click.Path(exists=True),
required=True,
help="Directory to be saved",
)

@click.option(
"--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
)
@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
def single(config_path, model_path, img_dir, txt_file, save_dir, cuda, crf):
"""
Inference from a single image
"""

# Setup
print (config_path)
CONFIG = OmegaConf.load(config_path)
device = get_device(cuda)
torch.set_grad_enabled(False)

classes = get_classtable(CONFIG)
postprocessor = setup_postprocessor(CONFIG) if crf else None

model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

print("Model:", CONFIG.MODEL.NAME)

testnames = pd.read_csv(txt_file,header=None)[0].to_list()

for testname in tqdm(testnames):

img_path = os.path.join(img_dir,testname.split('/')[2])
save_path = os.path.join(save_dir,testname.split('/')[2].replace('jpg','png'))

image = cv2.imread(img_path, cv2.IMREAD_COLOR)

image, raw_image, scale = preprocessing(image, device, CONFIG)
labelmap = inference(model, image, raw_image, postprocessor)

output = cv2.resize(labelmap, dsize=None, fx= 1/scale, fy=1/scale, interpolation=cv2.INTER_NEAREST)

cv2.imwrite(save_path, output)

if __name__ == "__main__":
main()