
Commit 78d53db

feat: enable to install a subset of vision models only

1 parent 71f1fb5

10 files changed: +105 -102 lines
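
Every hunk in this commit applies the same technique: third-party imports that previously sat at module scope are deferred into the constructor or method that actually uses them, so importing compressai_vision succeeds even when only some of the model backends (detectron2, jde, mmpose, segment_anything, yolox) are installed. A minimal sketch of the pattern, with hypothetical names standing in for the real wrappers:

# Hypothetical sketch of the deferred-import pattern used throughout this
# commit; "heavybackend" stands in for an optional dependency such as
# detectron2 or mmpose.

class SomeModelWrapper:
    def __init__(self, device: str):
        # The dependency is only needed once this wrapper is instantiated,
        # not when the package itself is imported.
        from heavybackend.modeling import build_model

        self.model = build_model(device)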

.gitignore

Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@
 *.bin
 *.inc
 *.tar.gz
-*.sh
 .DS_Store
 builds
 compressai_vision/version.py

compressai_vision/datasets/utils.py

Lines changed: 4 additions & 3 deletions

@@ -35,9 +35,6 @@
 import numpy as np
 import torch

-from jde.utils.datasets import letterbox
-from mmpose.structures.bbox import get_warp_matrix
-from segment_anything.utils.transforms import ResizeLongestSide
 from torch.nn import functional as F
 from torchvision import transforms

@@ -130,6 +127,7 @@ def __call__(self, dataset_dict):
         Returns:
             dict: a format that compressai-vision pipelines accept
         """
+        from mmpose.structures.bbox import get_warp_matrix

         dataset_dict = copy.deepcopy(dataset_dict)
         # the copied dictionary will be modified by code below
@@ -284,6 +282,7 @@ def __call__(self, dataset_dict):
         Returns:
             dict: a format that compressai-vision pipelines accept
         """
+        from jde.utils.datasets import letterbox

         dataset_dict = copy.deepcopy(dataset_dict)
         # the copied dictionary will be modified by code below
@@ -313,6 +312,8 @@ def __init__(self, img_size=1024):
         Args:
             img_size: single value - target size to SAM as input
         """
+        from segment_anything.utils.transforms import ResizeLongestSide
+
         self.target_size = img_size
         self.transform = ResizeLongestSide(img_size)
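
Moving an import into __call__ might look costly for a per-sample mapper, but Python caches modules in sys.modules, so only the first call pays the import price; later calls reduce to a cache lookup. A small illustration, with the math module standing in for mmpose/jde/segment_anything:

import sys

def hot_path(x):
    # Only the first call performs the actual import work; afterwards
    # Python just fetches the already-loaded module from sys.modules.
    import math  # stands in for an optional heavy dependency
    return math.sqrt(x)

hot_path(4.0)
assert "math" in sys.modules  # cached; repeat calls are cheap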

compressai_vision/evaluators/evaluators.py

Lines changed: 29 additions & 14 deletions

@@ -41,19 +41,8 @@
 import pandas as pd
 import torch

-from detectron2.data import MetadataCatalog
-from detectron2.evaluation import COCOEvaluator
-from detectron2.utils.visualizer import Visualizer
-from jde.utils.io import unzip_objs
-from mmpose.datasets.datasets import BaseCocoStyleDataset
-from mmpose.datasets.transforms import PackPoseInputs
-from mmpose.evaluation.metrics import CocoMetric
-from pycocotools.coco import COCO
 from pytorch_msssim import ms_ssim
 from tqdm import tqdm
-from yolox.data.datasets.coco import remove_useless_info
-from yolox.evaluators import COCOEvaluator as YOLOX_COCOEvaluator
-from yolox.utils import xyxy2xywh

 from compressai_vision.datasets import deccode_compressed_rle
 from compressai_vision.registry import register_evaluator
@@ -132,6 +121,8 @@ def __init__(

         self.set_annotation_info(dataset)

+        from detectron2.evaluation import COCOEvaluator
+
         self._evaluator = COCOEvaluator(
             dataset_name, False, output_dir=output_dir, use_fast_impl=False
         )
@@ -156,6 +147,9 @@ def save_visualization(self, gt, pred, output_dir, threshold):
         gt_image = cv2.resize(gt_image, (gt[0]["width"], gt[0]["height"]))

         img_id = gt[0]["image_id"]
+        from detectron2.data import MetadataCatalog
+        from detectron2.utils.visualizer import Visualizer
+
         metadata = MetadataCatalog.get(self.dataset_name)
         instances = pred[0]["instances"].to("cpu")
         if threshold:
@@ -343,6 +337,9 @@ def save_visualization(self, gt, pred, output_dir, threshold):
         gt_image = cv2.resize(gt_image, (gt[0]["width"], gt[0]["height"]))

         img_id = gt[0]["image_id"]
+        from detectron2.data import MetadataCatalog
+        from detectron2.utils.visualizer import Visualizer
+
         metadata = MetadataCatalog.get(self.dataset_name)
         instances = pred[0]["instances"].to("cpu")

@@ -575,6 +572,10 @@ def __init__(

         self.set_annotation_info(dataset)

+        from jde.utils.io import unzip_objs
+
+        self.unzip_objs = unzip_objs
+
         mm.lap.default_solver = "lap"
         self.dataset = dataset.dataset
         self.eval_info_file_name = self.get_jde_eval_info_name(self.dataset_name)
@@ -734,13 +735,13 @@ def mot_eval(self):
             frm_id = int(gt_frame["image_id"])

             pred_objs = self._predictions[frm_id].copy()
-            pred_tlwhs, pred_ids, _ = unzip_objs(pred_objs)
+            pred_tlwhs, pred_ids, _ = self.unzip_objs(pred_objs)

             gt_objs = gt_frame["annotations"]["gt"].copy()
-            gt_tlwhs, gt_ids, _ = unzip_objs(gt_objs)
+            gt_tlwhs, gt_ids, _ = self.unzip_objs(gt_objs)

             gt_ignore = gt_frame["annotations"]["gt_ignore"].copy()
-            gt_ignore_tlwhs, _, _ = unzip_objs(gt_ignore)
+            gt_ignore_tlwhs, _, _ = self.unzip_objs(gt_ignore)

             # remove ignored results
             keep = np.ones(len(pred_tlwhs), dtype=bool)
@@ -913,6 +914,10 @@ def __init__(
             datacatalog_name, dataset_name, dataset, output_dir, eval_criteria
         )

+        from pycocotools.coco import COCO
+        from yolox.data.datasets.coco import remove_useless_info
+        from yolox.evaluators import COCOEvaluator as YOLOX_COCOEvaluator
+
         self.set_annotation_info(dataset)

         cocoapi = COCO(self.annotation_path)
@@ -1055,6 +1060,16 @@ def __init__(
             dataset.get_org_mapper_func().compute_scale_and_center
         )

+        try:
+            from mmpose.datasets.datasets import BaseCocoStyleDataset
+            from mmpose.datasets.transforms import PackPoseInputs
+            from mmpose.evaluation.metrics import CocoMetric
+        except ImportError:
+            self._logger.error(
+                "Failed to import mmpose. Please install it, e.g. with 'pip install mmpose'."
+            )
+            raise
+
         if "metainfo" in args:
             metainfo = args["metainfo"]
         else:
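
The mmpose hunk is the one place where the deferred import is wrapped in try/except so the failure is logged with an actionable hint before re-raising. The same guard can be factored into a helper; the sketch below is illustrative, not code from the repository:

import importlib
import logging

logger = logging.getLogger(__name__)

def import_optional(name: str, hint: str):
    # Import an optional dependency, logging an actionable message on failure.
    try:
        return importlib.import_module(name)
    except ImportError:
        logger.error("Failed to import %s. %s", name, hint)
        raise

# e.g. mmpose = import_optional("mmpose", "Install it with 'pip install mmpose'.")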

compressai_vision/model_wrappers/__init__.py

Lines changed: 3 additions & 30 deletions

@@ -27,34 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-from .detectron2 import (
-    BaseWrapper,
-    faster_rcnn_R_50_FPN_3x,
-    faster_rcnn_X_101_32x8d_FPN_3x,
-    mask_rcnn_R_50_FPN_3x,
-    mask_rcnn_X_101_32x8d_FPN_3x,
-    panoptic_rcnn_R_101_FPN_3x,
-)
-from .jde import jde_1088x608
-from .rtmo import rtmo_multi_person_pose_estimation
-from .sam import (
-    sam_vit_b_01ec64,
-    sam_vit_h_4b8939,
-    sam_vit_l_0b3195,
-)
-from .yolox import yolox_darknet53
+from .base_wrapper import BaseWrapper
+from . import detectron2, jde, rtmo, sam, yolox

-__all__ = [
-    "BaseWrapper",
-    "faster_rcnn_X_101_32x8d_FPN_3x",
-    "mask_rcnn_X_101_32x8d_FPN_3x",
-    "faster_rcnn_R_50_FPN_3x",
-    "mask_rcnn_R_50_FPN_3x",
-    "panoptic_rcnn_R_101_FPN_3x",
-    "jde_1088x608",
-    "yolox_darknet53",
-    "rtmo_multi_person_pose_estimation",
-    "sam_vit_h_4b8939",
-    "sam_vit_b_01ec64",
-    "sam_vit_l_0b3195",
-]
+__all__ = ["BaseWrapper"]

compressai_vision/model_wrappers/detectron2.py

Lines changed: 24 additions & 14 deletions

@@ -34,15 +34,6 @@

 import torch

-from detectron2.checkpoint import DetectionCheckpointer
-from detectron2.config import get_cfg
-from detectron2.modeling import build_model
-from detectron2.modeling.meta_arch.panoptic_fpn import (
-    combine_semantic_and_instance_outputs,
-    detector_postprocess,
-    sem_seg_postprocess,
-)
-from detectron2.structures import ImageList

 from compressai_vision.registry import register_vision_model

@@ -159,10 +150,16 @@ def __str__(self):

 class Rcnn_R_50_X_101_FPN(BaseWrapper):
     def __init__(self, device: str, **kwargs):
+        from detectron2.checkpoint import DetectionCheckpointer
+        from detectron2.config import get_cfg
+        from detectron2.modeling import build_model
+
         super().__init__(device)

         self._cfg = get_cfg()
         self._cfg.MODEL.DEVICE = device
+        self.DetectionCheckpointer = DetectionCheckpointer
+        self.build_model = build_model
         _path_prefix = (
             f"{root_path}"
             if kwargs["model_path_prefix"] == "default"
@@ -171,11 +168,11 @@ def __init__(self, device: str, **kwargs):
         self._cfg.merge_from_file(f"{_path_prefix}/{kwargs['cfg']}")
         _integer_conv_weight = bool(kwargs["integer_conv_weight"])

-        self.model = build_model(self._cfg)
+        self.model = self.build_model(self._cfg)
         self.replace_conv2d_modules(self.model)
         self.model = self.model.to(device).eval()

-        DetectionCheckpointer(self.model).load(f"{_path_prefix}/{kwargs['weights']}")
+        self.DetectionCheckpointer(self.model).load(f"{_path_prefix}/{kwargs['weights']}")

         for param in self.model.parameters():
             param.requires_grad = False
@@ -271,6 +268,8 @@ def quantize_weights(model):
         return model

     def input_resize(self, images: List):
+        from detectron2.structures import ImageList
+
         return ImageList.from_tensors(images, self.size_divisibility)

     def input_to_features(self, x, device: str) -> Dict:
@@ -540,7 +539,18 @@ def __init__(self, device: str, **kwargs):
 @register_vision_model("panoptic_rcnn_R_101_FPN_3x")
 class panoptic_rcnn_R_101_FPN_3x(Rcnn_R_50_X_101_FPN):
     def __init__(self, device="cpu", **kwargs):
+        from detectron2.modeling.meta_arch.panoptic_fpn import (
+            combine_semantic_and_instance_outputs,
+            detector_postprocess,
+            sem_seg_postprocess,
+        )
+
         super().__init__(device, **kwargs)
+        self.sem_seg_postprocess = sem_seg_postprocess
+        self.detector_postprocess = detector_postprocess
+        self.combine_semantic_and_instance_outputs = (
+            combine_semantic_and_instance_outputs
+        )
         self.sem_seg_head = self.model.sem_seg_head

         combine_overlap_thresh = 0.5
@@ -590,12 +600,12 @@ def __init__(self, img_size: list):
         ):
             height = input_per_image["height"]
             width = input_per_image["width"]
-            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
-            detector_r = detector_postprocess(detector_result, height, width)
+            sem_seg_r = self.sem_seg_postprocess(sem_seg_result, image_size, height, width)
+            detector_r = self.detector_postprocess(detector_result, height, width)

             processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

-        panoptic_r = combine_semantic_and_instance_outputs(
+        panoptic_r = self.combine_semantic_and_instance_outputs(
             detector_r,
             sem_seg_r.argmax(dim=0),
             self.combine_overlap_thresh,
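
Because the imports now live inside __init__, any symbol needed later (e.g. self.detector_postprocess in the post-processing path above) is bound to the instance at construction time, so other methods never need a module-level import. The pattern in isolation, with functools.reduce standing in for a heavy detectron2 symbol:

from typing import Iterable

class LazyWrapper:
    def __init__(self):
        # Import once, then keep a reference on the instance so methods
        # defined elsewhere in the class can use the symbol directly.
        from functools import reduce  # stands in for a heavy dependency
        self._reduce = reduce

    def total(self, values: Iterable[int]) -> int:
        return self._reduce(lambda a, b: a + b, values, 0)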

compressai_vision/model_wrappers/jde.py

Lines changed: 22 additions & 20 deletions

@@ -32,30 +32,12 @@
 from pathlib import Path
 from typing import Dict, List

-import jde
 import torch

-from jde.models import Darknet
-from jde.tracker import matching
-from jde.tracker.basetrack import TrackState
-from jde.tracker.multitracker import (
-    STrack,
-    joint_stracks,
-    remove_duplicate_stracks,
-    sub_stracks,
-)
-from jde.utils.kalman_filter import KalmanFilter
-from jde.utils.utils import non_max_suppression, scale_coords
-
 from compressai_vision.registry import register_vision_model

 from .base_wrapper import BaseWrapper

-# Patch in modified create_modules
-from .jde_lowlevel import create_modules
-
-jde.models.create_modules = create_modules
-
 __all__ = [
     "jde_1088x608",
 ]
@@ -67,8 +49,19 @@
 @register_vision_model("jde_1088x608")
 class jde_1088x608(BaseWrapper):
     def __init__(self, device: str, **kwargs):
+        import jde
+        from jde.models import Darknet
+        from jde.utils.kalman_filter import KalmanFilter
+
+        from .jde_lowlevel import create_modules
+
+        jde.models.create_modules = create_modules
+
         super().__init__(device)

+        self.Darknet = Darknet
+        self.KalmanFilter = KalmanFilter
+
         _path_prefix = (
             f"{root_path}"
             if kwargs["model_path_prefix"] == "default"
@@ -99,7 +92,7 @@ def __init__(self, device: str, **kwargs):
             zip(self.split_layer_list, [None] * len(self.split_layer_list))
         )

-        self.darknet = Darknet(self.model_info["cfg"], device, nID=14455)
+        self.darknet = self.Darknet(self.model_info["cfg"], device, nID=14455)
         self.darknet.load_state_dict(
             torch.load(self.model_info["weights"], map_location="cpu")["model"],
             strict=False,
@@ -112,7 +105,7 @@ def __init__(self, device: str, **kwargs):
         if _integer_conv_weight:
             self.darknet = self.quantize_weights(self.darknet)

-        self.kalman_filter = KalmanFilter()
+        self.kalman_filter = self.KalmanFilter()

         if "logging_level" in kwargs:
             self.logger.level = kwargs["logging_level"]
@@ -219,6 +212,15 @@ def deeper_features_for_accuracy_proxy(self, x: Dict):
         # return x_deeper

     def _jde_process(self, pred, org_img_size: tuple, input_img_size: tuple):
+        from jde.tracker import matching
+        from jde.tracker.basetrack import TrackState
+        from jde.tracker.multitracker import (
+            STrack,
+            joint_stracks,
+            remove_duplicate_stracks,
+            sub_stracks,
+        )
+        from jde.utils.utils import non_max_suppression, scale_coords
         r"""Re-implementation of JDE from Z. Wang, L. Zheng, Y. Liu, and S. Wang:
         : `"Towards Real-Time Multi-Object Tracking"`_,
         The European Conference on Computer Vision (ECCV), 2020
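
One subtlety here: the jde.models.create_modules monkey-patch used to run at import time and now runs on first instantiation, just before any Darknet is built; re-assigning the same function on every construction is harmless. Reduced to its essentials (the class body is an illustrative sketch, and the relative import mirrors the diff so it is only valid inside the package):

class _jde_patch_sketch:
    def __init__(self):
        import jde.models

        from .jde_lowlevel import create_modules

        # Darknet resolves create_modules through its module globals at
        # construction time, so patching before the first Darknet is built
        # is sufficient; repeating the assignment is idempotent.
        jde.models.create_modules = create_modules

        self.Darknet = jde.models.Darknet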
