Skip to content

Commit 5f7b95a

Browse files
committed
[feat] visualization support for NN part 2
1 parent 53fbf1c commit 5f7b95a

File tree

8 files changed

+119
-4
lines changed

8 files changed

+119
-4
lines changed

cfgs/pipeline/remote_inference.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ evaluation:
3333
bypass: False
3434
dump: True
3535
evaluation_dir: "${codec.output_dir}/evaluation"
36+
visualization:
37+
save_visualization: "${codec.save_visualization}"
38+
visualization_dir: "${codec.output_dir}/visualization"
39+
threshold: 0 # only for detectron2, 0 means default setting of detectron2

cfgs/pipeline/split_inference.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,7 @@ evaluation:
4444
bypass: False
4545
dump: True
4646
evaluation_dir: "${codec.output_dir}/evaluation"
47+
visualization:
48+
save_visualization: "${codec.save_visualization}"
49+
visualization_dir: "${codec.output_dir}/visualization"
50+
threshold: 0 # only for detectron2, 0 means default setting of detectron2

compressai_vision/evaluators/evaluators.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
from compressai_vision.registry import register_evaluator
5353
from compressai_vision.utils import time_measure, to_cpu
5454

55+
from detectron2.utils.visualizer import Visualizer
56+
from detectron2.data import MetadataCatalog
57+
import cv2
58+
import os
59+
5560
from .base_evaluator import BaseEvaluator
5661
from .tf_evaluation_utils import (
5762
DetectionResultFields,
@@ -94,6 +99,27 @@ def reset(self):
9499

95100
def digest(self, gt, pred):
96101
return self._evaluator.process(gt, pred)
102+
103+
def save_visualization(self, gt, pred, output_dir, threshold):
104+
gt_image = gt[0]["image"]
105+
if torch.is_floating_point(gt_image):
106+
gt_image = (gt_image * 255).clamp(0, 255).to(torch.uint8)
107+
gt_image = gt_image[[2, 1, 0], ...]
108+
gt_image = gt_image.permute(1, 2, 0).cpu().numpy()
109+
gt_image = cv2.resize(gt_image, (gt[0]["width"], gt[0]["height"]))
110+
111+
img_id = gt[0]["image_id"]
112+
metadata = MetadataCatalog.get(self.dataset_name)
113+
instances = pred[0]["instances"].to("cpu")
114+
if threshold:
115+
keep = instances.scores >= threshold
116+
instances = instances[keep]
117+
118+
v = Visualizer(gt_image[:, :, ::-1], metadata, scale=1)
119+
out = v.draw_instance_predictions(instances) #selected_instances for specific class
120+
output_path = os.path.join(output_dir, f"{img_id}.jpg")
121+
cv2.imwrite(output_path, out.get_image()[:, :, ::-1])
122+
return
97123

98124
def results(self, save_path: str = None):
99125
out = self._evaluator.evaluate()
@@ -259,6 +285,28 @@ def digest(self, gt, pred):
259285

260286
return
261287

288+
def save_visualization(self, gt, pred, output_dir, threshold):
289+
gt_image = gt[0]["image"]
290+
if torch.is_floating_point(gt_image):
291+
gt_image = (gt_image * 255).clamp(0, 255).to(torch.uint8)
292+
gt_image = gt_image[[2, 1, 0], ...]
293+
gt_image = gt_image.permute(1, 2, 0).cpu().numpy()
294+
gt_image = cv2.resize(gt_image, (gt[0]["width"], gt[0]["height"]))
295+
296+
img_id = gt[0]["image_id"]
297+
metadata = MetadataCatalog.get(self.dataset_name)
298+
instances = pred[0]["instances"].to("cpu")
299+
300+
if threshold:
301+
keep = instances.scores >= threshold
302+
instances = instances[keep]
303+
304+
v = Visualizer(gt_image[:, :, ::-1], metadata, scale=1)
305+
out = v.draw_instance_predictions(instances) #selected_instances for specific class
306+
output_path = os.path.join(output_dir, f"{img_id}.jpg")
307+
cv2.imwrite(output_path, out.get_image()[:, :, ::-1])
308+
return
309+
262310
def _process_prediction(self, pred_dict):
263311
valid_cls = []
264312
valid_scores = []
@@ -425,6 +473,42 @@ def digest(self, gt, pred):
425473

426474
self._predictions[int(gt[0]["image_id"])] = pred_list
427475

476+
def save_visualization(self, gt, pred, output_dir, threshold):
477+
image_id = gt[0]["image_id"]
478+
gt_image = gt[0]["image"].permute(1, 2, 0).cpu().numpy()
479+
gt_image = (gt_image * 255).astype(np.uint8)
480+
gt_image = cv2.resize(cv2.cvtColor(gt_image, cv2.COLOR_RGB2BGR), (gt[0]["width"], gt[0]["height"]))
481+
online_im = self.plot_tracking(gt_image, pred["tlwhs"], pred["ids"], frame_id=image_id)
482+
output_path = os.path.join(output_dir, f"{image_id}.png")
483+
cv2.imwrite(output_path, online_im)
484+
return
485+
486+
def plot_tracking(self,image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
487+
im = np.ascontiguousarray(np.copy(image))
488+
im_h, im_w = im.shape[:2]
489+
490+
text_scale = max(1, image.shape[1] / 1600.)
491+
text_thickness = 1 if text_scale > 1.1 else 1
492+
line_thickness = max(1, int(image.shape[1] / 500.))
493+
494+
for i, tlwh in enumerate(tlwhs):
495+
x1, y1, w, h = tlwh
496+
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
497+
obj_id = int(obj_ids[i])
498+
id_text = '{}'.format(int(obj_id))
499+
if ids2 is not None:
500+
id_text = id_text + ', {}'.format(int(ids2[i]))
501+
color = self.get_color(abs(obj_id))
502+
cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
503+
cv2.putText(im, id_text, (intbox[0], intbox[1] + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
504+
thickness=text_thickness)
505+
return im
506+
507+
def get_color(self,idx):
508+
idx = idx * 3
509+
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
510+
return color
511+
428512
def results(self, save_path: str = None):
429513
out = self.mot_eval()
430514

compressai_vision/pipelines/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from pathlib import Path
3636
from typing import Callable, Dict
3737
from uuid import uuid4 as uuid
38-
38+
from omegaconf.errors import InterpolationResolutionError
3939
import torch
4040
import torch.nn as nn
4141
from torch import Tensor
@@ -81,6 +81,15 @@ def __init__(
8181
self.bitstream_name = self.configs["codec"]["bitstream_name"]
8282
self._output_ext = ".h5"
8383

84+
try:
85+
vis_flag = self.configs["visualization"].save_visualization
86+
except InterpolationResolutionError:
87+
vis_flag = False
88+
if vis_flag:
89+
self.vis_dir = self.configs["visualization"].visualization_dir
90+
self.vis_threshold = self.configs["visualization"].get('threshold', None)
91+
self._create_folder(self.vis_dir)
92+
8493
self.codec_output_dir = Path(self.configs["codec"]["codec_output_dir"])
8594
self.is_mac_calculation = self.configs["codec"]["measure_complexity"]
8695
self._create_folder(self.codec_output_dir)

compressai_vision/pipelines/remote_inference/image_remote_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def __call__(
155155
end = time_measure()
156156
timing["nn_task"].append((end - start))
157157

158+
if getattr(self, "vis_dir", None) and hasattr(evaluator, 'save_visualization'):
159+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
160+
158161
evaluator.digest(d, pred)
159162

160163
out_res = d[0].copy()

compressai_vision/pipelines/remote_inference/video_remote_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def __call__(
184184
end = time_measure()
185185
timing["nn_task"].append((end - start))
186186

187+
if getattr(self, "vis_dir", None) and hasattr(evaluator, 'save_visualization'):
188+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
189+
187190
evaluator.digest(d, pred)
188191

189192
out_res = d[0].copy()

compressai_vision/pipelines/split_inference/image_split_inference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ def __call__(
211211

212212
if evaluator:
213213
evaluator.digest(d, pred)
214+
if getattr(self, "vis_dir", None) and hasattr(evaluator, 'save_visualization'):
215+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
214216

215217
out_res = d[0].copy()
216218
del (

compressai_vision/pipelines/split_inference/video_split_inference.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from torch import Tensor
3636
from torch.utils.data import DataLoader
3737
from tqdm import tqdm
38-
38+
from itertools import repeat
3939
from compressai_vision.evaluators import BaseEvaluator
4040
from compressai_vision.model_wrappers import BaseWrapper
4141
from compressai_vision.registry import register_pipeline
@@ -277,7 +277,12 @@ def __call__(
277277
self.logger.info("Processing NN-Part2...")
278278
output_list = []
279279

280-
for e, ftensors in enumerate(tqdm(dec_ftensors_list)):
280+
if getattr(self, "vis_dir", None):
281+
dec_ftensors_list = zip(dec_ftensors_list, dataloader)
282+
else:
283+
dec_ftensors_list = zip(dec_ftensors_list, repeat(None))
284+
285+
for e, (ftensors, d) in enumerate(tqdm(dec_ftensors_list, total=len(dataloader))):
281286
data = {k: v.to(self.device_nn_part2) for k, v in ftensors.items()}
282287
dec_features["data"] = data
283288
dec_features["file_name"] = file_names[e]
@@ -302,7 +307,8 @@ def __call__(
302307

303308
if evaluator:
304309
evaluator.digest(gt_inputs[e], pred)
305-
310+
if getattr(self, "vis_dir", None) and hasattr(evaluator, 'save_visualization'):
311+
evaluator.save_visualization(d, pred, self.vis_dir, self.vis_threshold)
306312
out_res = dec_features.copy()
307313
del (out_res["data"], out_res["org_input_size"])
308314

0 commit comments

Comments
 (0)