Skip to content

Commit d1b632a

Browse files
authored
[feat] visualization support for NN part 2 (#16)
This commit implements visualization support for NN Part 2. The drawing functions are adapted from Detectron2 and JDE. To enable the feature, add ++codec.save_visualization=True to the command line of the evaluation scripts. To leave it disabled, keep the scripts unchanged.
1 parent 53fbf1c commit d1b632a

File tree

10 files changed

+154
-7
lines changed

10 files changed

+154
-7
lines changed

cfgs/pipeline/remote_inference.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ evaluation:
3333
bypass: False
3434
dump: True
3535
evaluation_dir: "${codec.output_dir}/evaluation"
36+
visualization:
37+
save_visualization: "${codec.save_visualization}"
38+
visualization_dir: "${codec.output_dir}/visualization"
39+
threshold: 0 # only for detectron2, 0 means default setting of detectron2

cfgs/pipeline/split_inference.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,7 @@ evaluation:
4444
bypass: False
4545
dump: True
4646
evaluation_dir: "${codec.output_dir}/evaluation"
47+
visualization:
48+
save_visualization: "${codec.save_visualization}"
49+
visualization_dir: "${codec.output_dir}/visualization"
50+
threshold: 0 # only for detectron2, 0 means default setting of detectron2

compressai_vision/evaluators/evaluators.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@
2929

3030
import json
3131
import math
32+
import os
3233
from collections import defaultdict
3334
from pathlib import Path
3435
from typing import Optional
3536

37+
import cv2
3638
import motmetrics as mm
3739
import numpy as np
3840
import pandas as pd
3941
import torch
42+
from detectron2.data import MetadataCatalog
4043
from detectron2.evaluation import COCOEvaluator
44+
from detectron2.utils.visualizer import Visualizer
4145
from jde.utils.io import unzip_objs
4246
from mmpose.datasets.datasets import BaseCocoStyleDataset
4347
from mmpose.datasets.transforms import PackPoseInputs
@@ -95,6 +99,29 @@ def reset(self):
9599
def digest(self, gt, pred):
96100
return self._evaluator.process(gt, pred)
97101

102+
def save_visualization(self, gt, pred, output_dir, threshold):
    """Draw the predicted instances on the input image and save it as a JPEG.

    Args:
        gt: sequence whose first item is a dict providing ``"image"`` (a
            channel-first torch tensor), ``"width"``, ``"height"`` and
            ``"image_id"``.
        pred: sequence whose first item is a dict providing ``"instances"``.
        output_dir: directory receiving ``<image_id>.jpg``.
        threshold: minimum score an instance needs to be drawn; a falsy
            value (``0`` or ``None``) disables the filtering.
    """
    sample = gt[0]

    frame = sample["image"]
    if torch.is_floating_point(frame):
        # Floating-point images are scaled by 255 and clamped to 8 bits.
        frame = (frame * 255).clamp(0, 255).to(torch.uint8)
    # Reverse the channel order, then go channel-first -> channel-last.
    frame = frame[[2, 1, 0], ...].permute(1, 2, 0).cpu().numpy()
    frame = cv2.resize(frame, (sample["width"], sample["height"]))

    image_id = sample["image_id"]
    dataset_metadata = MetadataCatalog.get(self.dataset_name)
    detections = pred[0]["instances"].to("cpu")
    if threshold:
        detections = detections[detections.scores >= threshold]

    # The channel axis is flipped on the way into the Visualizer and flipped
    # back on the rendered result before it is written to disk.
    drawer = Visualizer(frame[:, :, ::-1], dataset_metadata, scale=1)
    rendered = drawer.draw_instance_predictions(detections)
    target = os.path.join(output_dir, f"{image_id}.jpg")
    cv2.imwrite(target, rendered.get_image()[:, :, ::-1])
124+
98125
def results(self, save_path: str = None):
99126
out = self._evaluator.evaluate()
100127

@@ -259,6 +286,30 @@ def digest(self, gt, pred):
259286

260287
return
261288

289+
def save_visualization(self, gt, pred, output_dir, threshold):
    """Overlay the predicted instances on the input image and write a JPEG.

    Args:
        gt: sequence whose first element is a dict with the keys
            ``"image"`` (channel-first torch tensor), ``"width"``,
            ``"height"`` and ``"image_id"``.
        pred: sequence whose first element is a dict with an
            ``"instances"`` entry.
        output_dir: folder in which ``<image_id>.jpg`` is created.
        threshold: score cut-off applied to the instances; skipped when the
            value is falsy (``0`` or ``None``).
    """
    record = gt[0]
    picture = record["image"]

    # Bring floating-point images into the 8-bit range first.
    if torch.is_floating_point(picture):
        picture = (picture * 255).clamp(0, 255).to(torch.uint8)

    picture = picture[[2, 1, 0], ...]  # reverse the channel order
    picture = picture.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
    out_size = (record["width"], record["height"])
    picture = cv2.resize(picture, out_size)

    image_id = record["image_id"]
    meta = MetadataCatalog.get(self.dataset_name)
    found = pred[0]["instances"].to("cpu")

    if threshold:
        mask = found.scores >= threshold
        found = found[mask]

    painter = Visualizer(picture[:, :, ::-1], meta, scale=1)
    drawn = painter.draw_instance_predictions(found)
    # Flip the channel axis back before handing the result to OpenCV.
    result = drawn.get_image()[:, :, ::-1]
    cv2.imwrite(os.path.join(output_dir, f"{image_id}.jpg"), result)
312+
262313
def _process_prediction(self, pred_dict):
263314
valid_cls = []
264315
valid_scores = []
@@ -425,6 +476,57 @@ def digest(self, gt, pred):
425476

426477
self._predictions[int(gt[0]["image_id"])] = pred_list
427478

479+
def save_visualization(self, gt, pred, output_dir, threshold):
    """Draw the tracking results on the current frame and save it as a PNG.

    Args:
        gt: sequence whose first item is a dict providing ``"image"`` (a
            channel-first torch tensor), ``"width"``, ``"height"`` and
            ``"image_id"``.
        pred: mapping providing ``"tlwhs"`` (boxes) and ``"ids"`` (track ids).
        output_dir: directory receiving ``<image_id>.png``.
        threshold: accepted so the signature matches the other evaluators;
            this method does not use it.
    """
    sample = gt[0]
    frame_id = sample["image_id"]

    # Channel-first tensor -> channel-last 8-bit array.
    frame = sample["image"].permute(1, 2, 0).cpu().numpy()
    frame = (frame * 255).astype(np.uint8)
    # Convert to OpenCV's BGR ordering, then resize to the size in the record.
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    frame = cv2.resize(frame, (sample["width"], sample["height"]))

    annotated = self.plot_tracking(
        frame, pred["tlwhs"], pred["ids"], frame_id=frame_id
    )
    cv2.imwrite(os.path.join(output_dir, f"{frame_id}.png"), annotated)
492+
493+
def plot_tracking(
    self, image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0.0, ids2=None
):
    """Return a copy of ``image`` annotated with tracked boxes and their ids.

    Args:
        image: image array laid out as (height, width, ...).
        tlwhs: iterable of ``(x1, y1, w, h)`` boxes.
        obj_ids: track identifiers, aligned index-wise with ``tlwhs``.
        scores: unused; kept for signature compatibility.
        frame_id: unused; kept for signature compatibility.
        fps: unused; kept for signature compatibility.
        ids2: optional secondary identifiers, aligned with ``tlwhs``; when
            given, each label becomes ``"<id>, <id2>"``.

    Returns:
        A C-contiguous copy of ``image`` with the annotations drawn on it;
        the input array is left untouched.
    """
    canvas = np.ascontiguousarray(np.copy(image))

    width = image.shape[1]
    text_scale = max(1, width / 1600.0)
    # Label thickness is constant: it does not depend on the text scale.
    text_thickness = 1
    line_thickness = max(1, int(width / 500.0))

    for i, tlwh in enumerate(tlwhs):
        x1, y1, w, h = tlwh
        left, top, right, bottom = map(int, (x1, y1, x1 + w, y1 + h))

        obj_id = int(obj_ids[i])
        id_text = f"{obj_id}"
        if ids2 is not None:
            id_text = f"{id_text}, {int(ids2[i])}"

        # abs() keeps the colour derivation on a non-negative index.
        color = self.get_color(abs(obj_id))
        cv2.rectangle(
            canvas, (left, top), (right, bottom), color=color, thickness=line_thickness
        )
        cv2.putText(
            canvas,
            id_text,
            (left, top + 30),  # label placed 30 px below the box's top edge
            cv2.FONT_HERSHEY_PLAIN,
            text_scale,
            (0, 0, 255),
            thickness=text_thickness,
        )
    return canvas
524+
525+
def get_color(self, idx):
    """Derive a deterministic 3-component colour tuple from an integer index.

    The index is tripled and each component is that value times a fixed
    multiplier (37, 17, 29), reduced modulo 255.
    """
    scaled = 3 * idx
    return tuple((factor * scaled) % 255 for factor in (37, 17, 29))
529+
428530
def results(self, save_path: str = None):
429531
out = self.mot_eval()
430532

compressai_vision/pipelines/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import torch
4040
import torch.nn as nn
41+
from omegaconf.errors import InterpolationResolutionError
4142
from torch import Tensor
4243

4344
from compressai_vision.codecs.utils import (
@@ -81,6 +82,15 @@ def __init__(
8182
self.bitstream_name = self.configs["codec"]["bitstream_name"]
8283
self._output_ext = ".h5"
8384

85+
try:
86+
vis_flag = self.configs["visualization"].save_visualization
87+
except InterpolationResolutionError:
88+
vis_flag = False
89+
if vis_flag:
90+
self.vis_dir = self.configs["visualization"].visualization_dir
91+
self.vis_threshold = self.configs["visualization"].get("threshold", None)
92+
self._create_folder(self.vis_dir)
93+
8494
self.codec_output_dir = Path(self.configs["codec"]["codec_output_dir"])
8595
self.is_mac_calculation = self.configs["codec"]["measure_complexity"]
8696
self._create_folder(self.codec_output_dir)

compressai_vision/pipelines/fo_vcm/conversion/detectron2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
"""From 51 dataset into Detectron2-compatible dataset
31-
"""
30+
"""From 51 dataset into Detectron2-compatible dataset"""
3231
from math import floor
3332

3433
# import cv2

compressai_vision/pipelines/fo_vcm/patch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
"""monkey-patching for https://github.com/voxel51/fiftyone/issues/2096
31-
"""
30+
"""monkey-patching for https://github.com/voxel51/fiftyone/issues/2096"""
3231
import csv
3332

3433
# import importhook # this module simply ...ks up everything (at least torch imports)

compressai_vision/pipelines/remote_inference/image_remote_inference.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ def __call__(
155155
end = time_measure()
156156
timing["nn_task"].append((end - start))
157157

158+
if getattr(self, "vis_dir", None) and hasattr(
159+
evaluator, "save_visualization"
160+
):
161+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
162+
158163
evaluator.digest(d, pred)
159164

160165
out_res = d[0].copy()

compressai_vision/pipelines/remote_inference/video_remote_inference.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def __call__(
184184
end = time_measure()
185185
timing["nn_task"].append((end - start))
186186

187+
if getattr(self, "vis_dir", None) and hasattr(
188+
evaluator, "save_visualization"
189+
):
190+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
191+
187192
evaluator.digest(d, pred)
188193

189194
out_res = d[0].copy()

compressai_vision/pipelines/split_inference/image_split_inference.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ def __call__(
211211

212212
if evaluator:
213213
evaluator.digest(d, pred)
214+
if getattr(self, "vis_dir", None) and hasattr(
215+
evaluator, "save_visualization"
216+
):
217+
evaluator.save_visualization(
218+
d, pred, self.vis_dir, self.vis_threshold
219+
)
214220

215221
out_res = d[0].copy()
216222
del (

compressai_vision/pipelines/split_inference/video_split_inference.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030

3131
import os
32+
from itertools import repeat
3233
from typing import Dict, List, Tuple, TypeVar
3334

3435
import torch
@@ -277,7 +278,14 @@ def __call__(
277278
self.logger.info("Processing NN-Part2...")
278279
output_list = []
279280

280-
for e, ftensors in enumerate(tqdm(dec_ftensors_list)):
281+
if getattr(self, "vis_dir", None):
282+
dec_ftensors_list = zip(dec_ftensors_list, dataloader)
283+
else:
284+
dec_ftensors_list = zip(dec_ftensors_list, repeat(None))
285+
286+
for e, (ftensors, d) in enumerate(
287+
tqdm(dec_ftensors_list, total=len(dataloader))
288+
):
281289
data = {k: v.to(self.device_nn_part2) for k, v in ftensors.items()}
282290
dec_features["data"] = data
283291
dec_features["file_name"] = file_names[e]
@@ -302,7 +310,12 @@ def __call__(
302310

303311
if evaluator:
304312
evaluator.digest(gt_inputs[e], pred)
305-
313+
if getattr(self, "vis_dir", None) and hasattr(
314+
evaluator, "save_visualization"
315+
):
316+
evaluator.save_visualization(
317+
d, pred, self.vis_dir, self.vis_threshold
318+
)
306319
out_res = dec_features.copy()
307320
del (out_res["data"], out_res["org_input_size"])
308321

@@ -343,7 +356,7 @@ def __call__(
343356

344357
@staticmethod
345358
def _feature_tensor_list_to_dict(
346-
data: List[Dict[str, Tensor]]
359+
data: List[Dict[str, Tensor]],
347360
) -> Dict[str, Tensor]:
348361
"""
349362
Converts a list of feature tensors into a dictionary format.

0 commit comments

Comments
 (0)