
Commit 83309f4

jianwensongfracape authored and committed
[feat] hyperparameter settings in model wrapper
1 parent 740c852 · commit 83309f4

8 files changed: +68 −42 lines

cfgs/vision_model/default.yaml

Lines changed: 7 additions & 6 deletions
@@ -15,6 +15,10 @@ faster_rcnn_X_101_32x8d_FPN_3x:
   weights: "weights/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
   integer_conv_weight: False
   splits : "fpn" #, "c2" or "r2"
+  hyper_params:
+    update: False
+    conf_threshold: 0.05
+    max_dets: 100
 
 mask_rcnn_R_50_FPN_3x:
   model_path_prefix: ${..model_root_path}
@@ -48,12 +52,9 @@ jde_1088x608:
   cfg: "models/Towards-Realtime-MOT/cfg/yolov3_1088x608.cfg"
   weights: "weights/jde/jde.1088x608.uncertainty.pt"
   integer_conv_weight: False
-  iou_thres: 0.5
-  conf_thres: 0.5
-  nms_thres: 0.4
-  min_box_area: 200
-  track_buffer: 30
-  frame_rate: 30 # It is odd to consider this at here but following original code.
+  hyper_params:
+    update: False
+    conf_threshold: 0.5
   splits : [36, 61, 74] # MPEG FCM TEST with JDE on TVD
   #splits : [105, 90, 75] # MPEG FCM TEST with JDE on HiEve
 
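
The new hyper_params block is opt-in: nothing changes unless update is set to True. Below is a minimal sketch of that gate, assuming the YAML above has been parsed into a plain dict named cfg (a hypothetical name, not part of this commit):

cfg = {"hyper_params": {"update": False, "conf_threshold": 0.05, "max_dets": 100}}

hp = cfg.get("hyper_params", {})
if hp.get("update", False):
    # Only reached when update: True; keys left out of the YAML fall back to None.
    overrides = {"conf_threshold": hp.get("conf_threshold"), "max_dets": hp.get("max_dets")}
    print("applying", overrides)
else:
    print("hyper_params present but update is False, so model defaults are kept")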

compressai_vision/evaluators/evaluators.py

Lines changed: 1 addition & 3 deletions
@@ -717,9 +717,7 @@ def digest(self, gt, pred, mse_results=None):
         pred_list = []
         for tlwh, id in zip(pred["tlwhs"], pred["ids"]):
             x1, y1, w, h = tlwh
-            if (
-                self.apply_pred_offset
-            ):  # Replicate offset applied in load_motchallenge() in motmetrics library, used in VCM eval framework to load predictions from disk
+            if self.apply_pred_offset:  # Replicate offset applied in load_motchallenge() in motmetrics library, used in VCM eval framework to load predictions from disk
                 x1 -= 1
                 y1 -= 1
                 # x2, y2 = x1 + w, y1 + h
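
The reflowed condition keeps the behaviour unchanged: when apply_pred_offset is set, the predicted top-left corner is shifted by −1 on each axis, mirroring the 1-based to 0-based shift that motmetrics' load_motchallenge() applies when predictions are read from disk. A small illustrative sketch (the box values are made up):

tlwh = (10.0, 20.0, 50.0, 80.0)  # x1, y1, w, h as produced by the tracker
apply_pred_offset = True

x1, y1, w, h = tlwh
if apply_pred_offset:
    # Replicate the -1 offset so in-memory predictions match ones loaded from disk.
    x1 -= 1
    y1 -= 1
print(x1, y1, w, h)  # 9.0 19.0 50.0 80.0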

compressai_vision/model_wrappers/detectron2.py

Lines changed: 17 additions & 2 deletions
@@ -27,11 +27,13 @@
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import re
+
 from enum import Enum
 from pathlib import Path
 from typing import Dict, List, Optional
 
 import torch
+
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
 from detectron2.modeling import build_model
@@ -209,6 +211,15 @@ def __init__(self, device: str, **kwargs):
             zip(self.split_layer_list, [None] * len(self.split_layer_list))
         )
 
+        if kwargs.get("hyper_params", {}).get("update", False):
+            hyper_params = {
+                "conf_threshold": kwargs.get("hyper_params", {}).get(
+                    "conf_threshold", None
+                ),
+                "max_dets": kwargs.get("hyper_params", {}).get("max_dets", None),
+            }
+            self._apply_infer_overrides(hyper_params)
+
         assert self.top_block is not None
         assert self.proposal_generator is not None
 
@@ -313,9 +324,13 @@ def _apply_infer_overrides(self, overrides: Dict):
         """Overrides hyperparameters in roi_heads"""
 
         box_pred = getattr(self.roi_heads, "box_predictor", None)
-        if "conf_threshold" in overrides and hasattr(box_pred, "test_score_thresh"):
+        if overrides.get("conf_threshold") is not None and hasattr(
+            box_pred, "test_score_thresh"
+        ):
             box_pred.test_score_thresh = float(overrides["conf_threshold"])
-        if "max_dets" in overrides and hasattr(box_pred, "test_topk_per_image"):
+        if overrides.get("max_dets") is not None and hasattr(
+            box_pred, "test_topk_per_image"
+        ):
             box_pred.test_topk_per_image = int(overrides["max_dets"])
 
     @torch.no_grad()
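
Read together, the two hunks mean the overrides only take effect when hyper_params.update is True, and each override is applied only if it was provided and the box predictor exposes the matching Detectron2 test-time attribute (test_score_thresh, test_topk_per_image). A condensed, standalone sketch of that logic; the names here are illustrative, not the wrapper's API:

from typing import Dict, Optional

def apply_infer_overrides(box_pred, overrides: Dict[str, Optional[float]]) -> None:
    # Skip overrides that were not provided (None) or attributes the predictor lacks.
    if overrides.get("conf_threshold") is not None and hasattr(box_pred, "test_score_thresh"):
        box_pred.test_score_thresh = float(overrides["conf_threshold"])
    if overrides.get("max_dets") is not None and hasattr(box_pred, "test_topk_per_image"):
        box_pred.test_topk_per_image = int(overrides["max_dets"])

class DummyPredictor:  # stand-in for roi_heads.box_predictor
    test_score_thresh = 0.05
    test_topk_per_image = 100

pred = DummyPredictor()
apply_infer_overrides(pred, {"conf_threshold": 0.3, "max_dets": None})
print(pred.test_score_thresh, pred.test_topk_per_image)  # 0.3 100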

compressai_vision/model_wrappers/jde.py

Lines changed: 19 additions & 11 deletions
@@ -28,11 +28,13 @@
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import logging
+
 from pathlib import Path
 from typing import Dict, List, Optional
 
 import jde
 import torch
+
 from jde.models import Darknet
 from jde.tracker import matching
 from jde.tracker.basetrack import TrackState
@@ -78,12 +80,12 @@ def __init__(self, device: str, **kwargs):
         }
 
         self.model_configs = {
-            "iou_thres": float(kwargs["iou_thres"]),
-            "conf_thres": float(kwargs["conf_thres"]),
-            "nms_thres": float(kwargs["nms_thres"]),
-            "min_box_area": int(kwargs["min_box_area"]),
-            "track_buffer": int(kwargs["track_buffer"]),
-            "frame_rate": float(kwargs["frame_rate"]),
+            "iou_thres": float(kwargs.get("iou_thres", 0.5)),
+            "conf_thres": float(kwargs.get("conf_thres", 0.5)),
+            "nms_thres": float(kwargs.get("nms_thres", 0.4)),
+            "min_box_area": int(kwargs.get("min_box_area", 200)),
+            "track_buffer": int(kwargs.get("track_buffer", 30)),
+            "frame_rate": float(kwargs.get("frame_rate", 30)),
         }
         self.max_time_on_hold = int(
             self.model_configs["frame_rate"] / 30.0 * self.model_configs["track_buffer"]
@@ -116,6 +118,15 @@ def __init__(self, device: str, **kwargs):
             self.logger.level = kwargs["logging_level"]
             # logging.DEBUG
 
+        if kwargs.get("hyper_params", {}).get("update", False):
+            hyper_params = {
+                "conf_threshold": kwargs.get("hyper_params", {}).get(
+                    "conf_threshold", None
+                ),
+                "max_dets": kwargs.get("hyper_params", {}).get("max_dets", None),
+            }
+            self._apply_infer_overrides(hyper_params)
+
         # reset member variables to use over a sequence of frame
         self.reset()
 
@@ -210,8 +221,7 @@ def _feature_pyramid_to_output(
         return {"tlwhs": online_tlwhs, "ids": online_ids}
 
     def _apply_infer_overrides(self, overrides: Dict):
-
-        if "conf_threshold" in overrides:
+        if overrides.get("conf_threshold") is not None:
             self.model_configs["conf_thres"] = float(overrides["conf_threshold"])
 
     @torch.no_grad()
@@ -337,9 +347,7 @@ def _jde_process(self, pred, org_img_size: tuple, input_img_size: tuple):
 
         detections = [detections[i] for i in u_detection]
         # detections is now a list of the unmatched detections
-        r_tracked_stracks = (
-            []
-        )  # This is container for stracks which were tracked till the
+        r_tracked_stracks = []  # This is container for stracks which were tracked till the
         # previous frame but no detection was found for it in the current frame
         for i in u_track:
             if track_candidates_pool[i].state == TrackState.Tracked:
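
The JDE change has two parts: the tracker settings that used to be mandatory kwargs now fall back to the defaults previously hard-coded in the YAML, and a hyper_params block with update: True can still override conf_thres afterwards. A plain-dict sketch of how the two interact (no JDE objects involved, values illustrative):

kwargs = {"hyper_params": {"update": True, "conf_threshold": 0.4}}

model_configs = {
    "iou_thres": float(kwargs.get("iou_thres", 0.5)),
    "conf_thres": float(kwargs.get("conf_thres", 0.5)),
    "nms_thres": float(kwargs.get("nms_thres", 0.4)),
    "min_box_area": int(kwargs.get("min_box_area", 200)),
    "track_buffer": int(kwargs.get("track_buffer", 30)),
    "frame_rate": float(kwargs.get("frame_rate", 30)),
}

hp = kwargs.get("hyper_params", {})
if hp.get("update", False) and hp.get("conf_threshold") is not None:
    # Mirrors _apply_infer_overrides: only conf_thres is overridable for JDE here.
    model_configs["conf_thres"] = float(hp["conf_threshold"])

print(model_configs["conf_thres"])  # 0.4: the override wins over the 0.5 default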

compressai_vision/pipelines/base.py

Lines changed: 10 additions & 10 deletions
@@ -31,13 +31,15 @@
 import json
 import logging
 import os
+
 from enum import Enum
 from pathlib import Path
 from typing import Callable, Dict, Tuple
 from uuid import uuid4 as uuid
 
 import torch
 import torch.nn as nn
+
 from omegaconf.errors import InterpolationResolutionError
 from torch import Tensor
 
@@ -180,9 +182,7 @@ def _update_codec_configs_at_pipeline_level(self, total_num_frames):
         if n_frames_to_be_encoded == -1:
             n_frames_to_be_encoded = total_num_frames
 
-        assert (
-            n_frames_to_be_encoded
-        ), f"Number of frames to be encoded must be greater than 0, but got {n_frames_to_be_encoded}"
+        assert n_frames_to_be_encoded, f"Number of frames to be encoded must be greater than 0, but got {n_frames_to_be_encoded}"
 
         if (self._codec_skip_n_frames + n_frames_to_be_encoded) > total_num_frames:
             self.logger.warning(
@@ -200,7 +200,9 @@ def _update_codec_configs_at_pipeline_level(self, total_num_frames):
             self._codec_skip_n_frames > 0
             or self._codec_n_frames_to_be_encoded != total_num_frames
         ):
-            assert self.configs["codec"][
+            assert self.configs[
+                "codec"
+            ][
                 "encode_only"
             ], "Encoding part of a sequence is only available when `codec.encode_only' is True"
 
@@ -220,8 +222,8 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
         assert (
             n_bits == 8 or n_bits == 16
         ), "currently it only supports dumping features in 8 bits or 16 bits"
-        assert datacatalog_name in list(
-            MIN_MAX_DATASET.keys()
+        assert (
+            datacatalog_name in list(MIN_MAX_DATASET.keys())
         ), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
         minv, maxv = MIN_MAX_DATASET[datacatalog_name]
         data_features = {}
@@ -259,8 +261,8 @@ def _post_process_loaded_features(features, n_bits, datacatalog_name):
         assert (
             n_bits == 8 or n_bits == 16
         ), "currently it only supports dumping features in 8 bits or 16 bits"
-        assert datacatalog_name in list(
-            MIN_MAX_DATASET.keys()
+        assert (
+            datacatalog_name in list(MIN_MAX_DATASET.keys())
         ), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
         minv, maxv = MIN_MAX_DATASET[datacatalog_name]
         data_features = {}
@@ -488,13 +490,11 @@ def calc_feature_mse(
         input_feats: Dict[str, torch.Tensor],
         recon_feats: Dict[str, torch.Tensor],
     ) -> Dict[str, float]:
-
         mse_results: Dict[str, float] = {}
 
         keys_recon = list(recon_feats.keys())
 
         for i, key in enumerate(input_feats.keys()):
-
             x = input_feats[key].cpu()
             y = recon_feats[keys_recon[i]].cpu()
 
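
The calc_feature_mse hunk only drops blank lines, but it shows the shape of the computation: reconstructed feature tensors are matched to the input tensors by position and compared on the CPU. A hedged, self-contained sketch of a per-layer MSE in that spirit (the reduction step is not shown in the diff and is assumed here; key names and shapes are illustrative):

import torch

def feature_mse(input_feats, recon_feats):
    results = {}
    keys_recon = list(recon_feats.keys())
    for i, key in enumerate(input_feats.keys()):
        x = input_feats[key].cpu()
        y = recon_feats[keys_recon[i]].cpu()
        results[key] = torch.mean((x - y) ** 2).item()  # assumed reduction
    return results

feats_in = {"p2": torch.zeros(1, 256, 8, 8)}
feats_rec = {"p2": torch.full((1, 256, 8, 8), 0.1)}
print(feature_mse(feats_in, feats_rec))  # {'p2': ~0.01}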

compressai_vision/pipelines/split_inference/image_split_inference.py

Lines changed: 10 additions & 10 deletions
@@ -28,9 +28,11 @@
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import os
+
 from typing import Dict
 
 import torch
+
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
@@ -210,19 +212,17 @@ def __call__(
             self.update_time_elapsed("nn_part_2", (time_measure() - start))
 
             if evaluator:
-                mse_results = None
-                if (
+                mse_enabled = (
                     evaluator.calculate_feature_mse
                     and not self.configs["codec"]["decode_only"]
-                ):
-                    mse_results = self.calc_feature_mse(
-                        featureT["data"], dec_features["data"]
-                    )
+                )
+                mse_results = (
+                    self.calc_feature_mse(featureT["data"], dec_features["data"])
+                    if mse_enabled
+                    else None
+                )
 
-                if mse_results:
-                    evaluator.digest(d, pred, mse_results)
-                else:
-                    evaluator.digest(d, pred)
+                evaluator.digest(d, pred, mse_results)
 
             if getattr(self, "vis_dir", None) and hasattr(
                 evaluator, "save_visualization"
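
The two digest branches collapse into a single call because digest() already defaults mse_results to None (see the evaluators.py hunk above), so passing None explicitly matches omitting the argument. A toy illustration with a stand-in function that has the same default:

def digest(gt, pred, mse_results=None):  # stand-in, not the real evaluator method
    return gt, pred, mse_results

mse_enabled = False
mse_results = {"p2": 0.01} if mse_enabled else None
assert digest("gt", "pred", mse_results) == digest("gt", "pred")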

compressai_vision/pipelines/split_inference/video_split_inference.py

Lines changed: 2 additions & 0 deletions
@@ -29,10 +29,12 @@
 
 
 import os
+
 from itertools import repeat
 from typing import Dict, List, Tuple, TypeVar
 
 import torch
+
 from torch import Tensor
 from torch.utils.data import DataLoader
 from tqdm import tqdm

compressai_vision/run/eval_split_inference.py

Lines changed: 2 additions & 0 deletions
@@ -45,11 +45,13 @@
 
 import logging
 import os
+
 from pathlib import Path
 from typing import Any
 
 import hydra
 import pandas as pd
+
 from omegaconf import DictConfig
 from tabulate import tabulate
 
