From 506c4d521224a50bd0308338145143d2a8934bad Mon Sep 17 00:00:00 2001 From: bala93 Date: Fri, 14 Oct 2022 14:28:38 -0400 Subject: [PATCH] A demo code for test voc12 dataset --- demo_testvoc.py | 193 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 demo_testvoc.py diff --git a/demo_testvoc.py b/demo_testvoc.py new file mode 100644 index 0000000..ebf834d --- /dev/null +++ b/demo_testvoc.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python +# coding: utf-8 +# +# Author: Kazuto Nakashima +# URL: https://kazuto1011.github.io +# Date: 07 January 2019 + +from __future__ import absolute_import, division, print_function +from configparser import Interpolation + +import click +import cv2 +import matplotlib +import matplotlib.cm as cm +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +from libs.models import * +from libs.utils import DenseCRF + +import os +import pandas as pd +from tqdm import tqdm + +def get_device(cuda): + cuda = cuda and torch.cuda.is_available() + device = torch.device("cuda" if cuda else "cpu") + if cuda: + current_device = torch.cuda.current_device() + print("Device:", torch.cuda.get_device_name(current_device)) + else: + print("Device: CPU") + return device + + +def get_classtable(CONFIG): + with open(CONFIG.DATASET.LABELS) as f: + classes = {} + for label in f: + label = label.rstrip().split("\t") + classes[int(label[0])] = label[1].split(",")[0] + return classes + + +def setup_postprocessor(CONFIG): + # CRF post-processor + postprocessor = DenseCRF( + iter_max=CONFIG.CRF.ITER_MAX, + pos_xy_std=CONFIG.CRF.POS_XY_STD, + pos_w=CONFIG.CRF.POS_W, + bi_xy_std=CONFIG.CRF.BI_XY_STD, + bi_rgb_std=CONFIG.CRF.BI_RGB_STD, + bi_w=CONFIG.CRF.BI_W, + ) + return postprocessor + + +def preprocessing(image, device, CONFIG): + # Resize + scale = CONFIG.IMAGE.SIZE.TEST / max(image.shape[:2]) + image = cv2.resize(image, dsize=None, fx=scale, fy=scale) + raw_image = image.astype(np.uint8) + + # Subtract mean values + image = image.astype(np.float32) + image -= np.array( + [ + float(CONFIG.IMAGE.MEAN.B), + float(CONFIG.IMAGE.MEAN.G), + float(CONFIG.IMAGE.MEAN.R), + ] + ) + + # Convert to torch.Tensor and add "batch" axis + image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0) + image = image.to(device) + + return image, raw_image, scale + + +def inference(model, image, raw_image=None, postprocessor=None): + _, _, H, W = image.shape + + # Image -> Probability map + logits = model(image) + logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False) + probs = F.softmax(logits, dim=1)[0] + probs = probs.cpu().numpy() + + # Refine the prob map with CRF + if postprocessor and raw_image is not None: + probs = postprocessor(raw_image, probs) + + labelmap = np.argmax(probs, axis=0) + + return labelmap + + +@click.group() +@click.pass_context +def main(ctx): + """ + Demo with a trained model + """ + + print("Mode:", ctx.invoked_subcommand) + +@main.command() +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "-m", + "--model-path", + type=click.Path(exists=True), + required=True, + help="PyTorch model to be loaded", +) +@click.option( + "-f", + "--txt-file", + type=click.Path(exists=True), + required=True, + help="CSV file to be processed", +) +@click.option( + "-i", + "--img-dir", + type=click.Path(exists=True), + required=True, + help="Image to be processed", +) + +@click.option( + "-s", + "--save-dir", + type=click.Path(exists=True), + required=True, + help="Directory to be saved", +) + +@click.option( + "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]" +) +@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing") +def single(config_path, model_path, img_dir, txt_file, save_dir, cuda, crf): + """ + Inference from a single image + """ + + # Setup + print (config_path) + CONFIG = OmegaConf.load(config_path) + device = get_device(cuda) + torch.set_grad_enabled(False) + + classes = get_classtable(CONFIG) + postprocessor = setup_postprocessor(CONFIG) if crf else None + + model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES) + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(state_dict) + model.eval() + model.to(device) + + print("Model:", CONFIG.MODEL.NAME) + + testnames = pd.read_csv(txt_file,header=None)[0].to_list() + + for testname in tqdm(testnames): + + img_path = os.path.join(img_dir,testname.split('/')[2]) + save_path = os.path.join(save_dir,testname.split('/')[2].replace('jpg','png')) + + image = cv2.imread(img_path, cv2.IMREAD_COLOR) + + image, raw_image, scale = preprocessing(image, device, CONFIG) + labelmap = inference(model, image, raw_image, postprocessor) + + output = cv2.resize(labelmap, dsize=None, fx= 1/scale, fy=1/scale, interpolation=cv2.INTER_NEAREST) + + cv2.imwrite(save_path, output) + +if __name__ == "__main__": + main()