
Commit 4fb20be

jianwensongfracape authored and committed
[fix] classwise results
1 parent 5ee5d54 commit 4fb20be

File tree

4 files changed: +55 −62 lines changed


compressai_vision/pipelines/remote_inference/video_remote_inference.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def __call__(
             start = time_measure()
             dec_d = {
                 "file_name": dec_seq["file_names"][e],
-                "file_origin": d[e]["file_name"],
+                "file_origin": d[0]["file_name"],
             }
             # dec_d = {"file_name": dec_seq[0]["file_names"][e]}
             pred = vision_model.forward(org_map_func(dec_d))
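Note on the change above, as a self-contained sketch rather than the pipeline's actual objects: assuming dec_seq["file_names"] lists one name per decoded frame while the input d holds a single entry describing the origin file, indexing d with the frame counter e would go out of range after the first frame.

# Hypothetical illustration of the indexing fix; names and values are assumed,
# not taken from video_remote_inference.py.
dec_seq = {"file_names": ["frame_000.png", "frame_001.png", "frame_002.png"]}
d = [{"file_name": "BasketballDrill_832x480_50.yuv"}]  # assumed single origin entry

for e, fname in enumerate(dec_seq["file_names"]):
    dec_d = {
        "file_name": fname,
        "file_origin": d[0]["file_name"],  # d[e] would raise IndexError for e > 0
    }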

scripts/metrics/compute_overall_map.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@
     ],
     CLASSES[1]: ["BasketballDrill", "BQMall", "PartyScene", "RaceHorses_832x480"],
     CLASSES[2]: ["BasketballPass", "BQSquare", "BlowingBubbles", "RaceHorses"],
-    CLASSES[3]: ["Traffic", "BQTerrace"]
+    CLASSES[3]: ["Traffic", "BQTerrace"],
 }
 
 SEQUENCE_TO_OFFSET = {

scripts/metrics/gen_mpeg_cttc_csv.py

Lines changed: 40 additions & 51 deletions
@@ -58,7 +58,13 @@
 DATASETS = ["TVD", "SFU", "OIV6", "HIEVE", "PANDASET"]
 
 
-def read_df_rec(path, seq_list, nb_operation_points, fn_regex=r"summary.csv", prefix: str | None = None,):
+def read_df_rec(
+    path,
+    seq_list,
+    nb_operation_points,
+    fn_regex=r"summary.csv",
+    prefix: str | None = None,
+):
     summary_csvs = [f for f in iglob(join(path, "**", fn_regex), recursive=True)]
     if nb_operation_points > 0:
         seq_names = [
@@ -163,10 +169,13 @@ def generate_csv_classwise_video_map(
     nb_operation_points: int = 4,
     no_cactus: bool = False,
     skip_classwise: bool = False,
-    prefix: str | None = None,
+    seq_prefix: str = None,
+    dataset_prefix: str = None,
 ):
     opts_metrics = {"AP": 0, "AP50": 1, "AP75": 2, "APS": 3, "APM": 4, "APL": 5}
-    results_df = read_df_rec(result_path, seq_list, nb_operation_points, prefix=prefix)
+    results_df = read_df_rec(
+        result_path, seq_list, nb_operation_points, prefix=seq_prefix
+    )
 
     # sort
     sorterIndex = dict(zip(seq_list, range(len(seq_list))))
@@ -186,6 +195,13 @@ def generate_csv_classwise_video_map(
         classwise_name = list(seqs_by_class.keys())[0]
         classwise_seqs = list(seqs_by_class.values())[0]
 
+        cur_seq_prefix = (
+            seq_prefix
+            if seq_prefix
+            and any(name.startswith(seq_prefix) for name in classwise_seqs)
+            else None
+        )
+
         class_wise_maps = []
         for q in range(nb_operation_points):
             items = utils.search_items(
@@ -196,6 +212,8 @@
                 BaseEvaluator.get_coco_eval_info_name,
                 by_name=True,
                 gt_folder=gt_folder,
+                seq_prefix=cur_seq_prefix,
+                dataset_prefix=dataset_prefix,
             )
 
             assert (
@@ -211,7 +229,11 @@
         matched_seq_names = []
         for seq_info in items:
             name, _, _ = get_seq_info(seq_info[utils.SEQ_INFO_KEY])
-            matched_seq_names.append(name)
+            matched_seq_names.append(
+                f"{seq_prefix}{name}"
+                if seq_prefix and seq_prefix in seq_info[utils.SEQ_NAME_KEY]
+                else name
+            )
 
         class_wise_results_df = generate_classwise_df(
             results_df, {classwise_name: matched_seq_names}
@@ -220,7 +242,7 @@
 
         output_df = df_append(output_df, class_wise_results_df)
 
-    return output_df, results_df
+    return output_df
 
 
 def generate_csv_classwise_video_mota(
@@ -436,6 +458,7 @@ def generate_csv(result_path, seq_list, nb_operation_points):
 
     if args.dataset_name == "SFU":
         metric = args.metric
+        dataset_prefix = "sfu-hw-"
        class_ab = {
             "CLASS-AB": [
                 "Traffic",
@@ -496,6 +519,13 @@ def generate_csv(result_path, seq_list, nb_operation_points):
             ns_seq_list = ["ns_Traffic_2560x1600_30", "ns_BQTerrace_1920x1080_60"]
             seq_list.extend(ns_seq_list)
             seq_prefix = "ns_"
+            class_ab_star = {
+                "CLASS-AB*": [
+                    "ns_Traffic",
+                    "ns_BQTerrace",
+                ]
+            }
+            classes.append(class_ab_star)
 
         if args.mode == "VCM" and not args.include_optional:
             seq_list.remove("Kimono_1920x1080_24")
@@ -504,7 +534,7 @@ def generate_csv(result_path, seq_list, nb_operation_points):
         if args.mode == "FCM" and args.no_cactus:
             seq_list.remove("Cactus_1920x1080_50")
 
-        output_df, results_df = generate_csv_classwise_video_map(
+        output_df = generate_csv_classwise_video_map(
             norm_result_path,
             args.dataset_path,
             classes,
@@ -514,7 +544,10 @@ def generate_csv(result_path, seq_list, nb_operation_points):
             args.nb_operation_points,
             args.no_cactus,
             args.mode == "VCM", # skip classwise evaluation
-            prefix=seq_prefix if "seq_prefix" in locals() else None, # adding prefix to non-scale sequence
+            seq_prefix=seq_prefix
+            if "seq_prefix" in locals()
+            else None, # adding prefix to non-scale sequence
+            dataset_prefix=dataset_prefix if "dataset_prefix" in locals() else None,
         )
 
         if args.mode == "VCM":
@@ -524,50 +557,6 @@ def generate_csv(result_path, seq_list, nb_operation_points):
                 perf_name="end_accuracy",
                 rate_name="bitrate (kbps)",
             )
-        else:
-            # add CLASS-AB* using ns_* results for Traffic and BQTerrace
-            if args.add_non_scale:
-                class_ab_star = {
-                    "CLASS-AB*": [
-                        "ns_Traffic",
-                        "ns_BQTerrace",
-                    ]
-                }
-                # Compute classwise mAP for AB* using ns_* eval results but original GT
-                class_wise_maps = []
-                for q in range(args.nb_operation_points):
-                    items = utils.search_items(
-                        norm_result_path,
-                        args.dataset_path,
-                        q,
-                        list(class_ab_star.values())[0],
-                        BaseEvaluator.get_coco_eval_info_name,
-                        by_name=True,
-                        gt_folder=args.gt_folder,
-                        gt_name_overrides={
-                            "ns_Traffic": "Traffic",
-                            "ns_BQTerrace": "BQTerrace",
-                        },
-                    )
-
-                    assert (
-                        len(items) > 0
-                    ), "No evaluation information found for CLASS-AB* in provided result directories..."
-
-                    summary = compute_overall_mAP("CLASS-AB*", items)
-                    maps = summary.values[0][{"AP": 0, "AP50": 1}[metric]]
-                    class_wise_maps.append(maps)
-
-                matched_seq_names = []
-                for seq_info in items:
-                    name, _, _ = get_seq_info(seq_info[utils.SEQ_INFO_KEY])
-                    matched_seq_names.append(name)
-
-                class_wise_results_df = generate_classwise_df(
-                    results_df, {"CLASS-AB*": matched_seq_names}
-                )
-                class_wise_results_df["end_accuracy"] = class_wise_maps
-                output_df = df_append(output_df, class_wise_results_df)
     elif args.dataset_name == "OIV6":
         output_df = generate_csv(
             norm_result_path, ["MPEGOIV6"], args.nb_operation_points
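A minimal sketch of the prefix handling this diff introduces, with illustrative values (the names mirror the script, but this is not its actual call site): the sequence prefix is only forwarded for class groups whose sequences carry it, and the ground-truth lookup strips it again.

seq_prefix = "ns_"
classwise_seqs = ["ns_Traffic", "ns_BQTerrace"]  # e.g. the CLASS-AB* group

# Forward the prefix only when the group's sequence names actually start with it.
cur_seq_prefix = (
    seq_prefix
    if seq_prefix and any(name.startswith(seq_prefix) for name in classwise_seqs)
    else None
)

# Ground truth is looked up under the original (un-prefixed) sequence name.
for seq_name in classwise_seqs:
    gt_lookup_name = seq_name.split(cur_seq_prefix)[-1] if cur_seq_prefix else seq_name
    print(seq_name, "->", gt_lookup_name)  # e.g. ns_Traffic -> Traffic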

scripts/metrics/utils.py

Lines changed: 13 additions & 9 deletions
@@ -40,7 +40,7 @@
 import re
 
 from pathlib import Path
-from typing import Optional, Dict
+from typing import Dict, Optional
 
 __all__ = [
     "get_seq_number",
@@ -86,8 +86,10 @@ def get_eval_info_path_by_seq_num(seq_num, _path, qidx: int, name_func: callable
     return eval_info_path, dname
 
 
-def get_eval_info_path_by_seq_name(seq_name, _path, _qidx: int, name_func: callable):
-    result = get_folder_path_by_seq_name(seq_name, _path)
+def get_eval_info_path_by_seq_name(
+    seq_name, _path, _qidx: int, name_func: callable, dataset_prefix=None
+):
+    result = get_folder_path_by_seq_name(seq_name, _path, dataset_prefix)
     if result is None:
         return
     eval_folder, _dname = result
@@ -156,9 +158,12 @@ def get_folder_path_by_seq_num(seq_num, _path):
     return None
 
 
-def get_folder_path_by_seq_name(seq_name, _path):
+def get_folder_path_by_seq_name(seq_name, _path, dataset_prefix=None):
     _folder_list = [f for f in Path(_path).iterdir() if f.is_dir()]
 
+    if dataset_prefix is not None:
+        seq_name = f"{dataset_prefix}{seq_name}"
+
     for _name in _folder_list:
         if seq_name in _name.stem:
             return _name.resolve(), _name.stem
@@ -175,13 +180,14 @@ def search_items(
     by_name=False,
     pandaset_flag=False,
     gt_folder="annotations",
-    gt_name_overrides: Optional[Dict[str, str]] = None,
+    seq_prefix=None,
+    dataset_prefix=None,
 ):
     _ret_list = []
     for seq_name in seq_list:
         if by_name is True:
             result = get_eval_info_path_by_seq_name(
-                seq_name, result_path, rate_point, eval_func
+                seq_name, result_path, rate_point, eval_func, dataset_prefix
             )
             if result is None:
                 continue
@@ -192,9 +198,7 @@
             )
         else:
             _gt_lookup_name = (
-                gt_name_overrides.get(seq_name, seq_name)
-                if gt_name_overrides
-                else seq_name
+                seq_name.split(seq_prefix)[-1] if seq_prefix else seq_name
             )
             seq_info_path, seq_gt_path = get_seq_info_path_by_seq_name(
                 _gt_lookup_name, dataset_path, gt_folder
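Sketch of why dataset_prefix is threaded into the folder lookup, using hypothetical folder names: plain substring matching on a bare sequence name can also hit the non-scaled ns_* results, while the prefixed name only matches the intended dataset folder.

folders = ["sfu-hw-Traffic_2560x1600_30", "ns_Traffic_2560x1600_30"]  # assumed layout

def candidate_folders(seq_name, dataset_prefix=None):
    # Mirrors the substring test in get_folder_path_by_seq_name (illustration only).
    if dataset_prefix is not None:
        seq_name = f"{dataset_prefix}{seq_name}"
    return [f for f in folders if seq_name in f]

print(candidate_folders("Traffic"))                            # both folders match
print(candidate_folders("Traffic", dataset_prefix="sfu-hw-"))  # only the sfu-hw- folder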
