Skip to content

Commit 3208fdd

Browse files
jianwensongfracape
authored andcommitted
[fix] mse error for other inferences
1 parent 772edba commit 3208fdd

File tree

2 files changed

+45
-32
lines changed

2 files changed

+45
-32
lines changed

compressai_vision/evaluators/evaluators.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,10 @@ def results(self, save_path: str = None):
176176

177177
self.write_results(out)
178178

179-
overall_mse = None
179+
# summary = {}
180+
# for key, item_dict in out.items():
181+
# summary[f"{key}"] = item_dict["AP"]
182+
180183
if self._mse_results:
181184
mse_results_dict = {"per_frame_mse": self._mse_results}
182185
overall_mse = 0.0
@@ -205,11 +208,9 @@ def results(self, save_path: str = None):
205208
) as f:
206209
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
207210

208-
# summary = {}
209-
# for key, item_dict in out.items():
210-
# summary[f"{key}"] = item_dict["AP"]
211-
212-
return out, overall_mse
211+
return out, overall_mse
212+
else:
213+
return out
213214

214215

215216
@register_evaluator("OIC-EVAL")
@@ -474,7 +475,11 @@ def results(self, save_path: str = None):
474475

475476
self.write_results(out)
476477

477-
overall_mse = None
478+
summary = {}
479+
for key, value in out.items():
480+
name = "-".join(key.split("/")[1:])
481+
summary[name] = value
482+
478483
if self._mse_results:
479484
mse_results_dict = {"per_frame_mse": self._mse_results}
480485
overall_mse = 0.0
@@ -503,12 +508,9 @@ def results(self, save_path: str = None):
503508
) as f:
504509
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
505510

506-
summary = {}
507-
for key, value in out.items():
508-
name = "-".join(key.split("/")[1:])
509-
summary[name] = value
510-
511-
return summary, overall_mse
511+
return summary, overall_mse
512+
else:
513+
return summary
512514

513515

514516
@register_evaluator("SEMANTICSEG-EVAL")
@@ -607,7 +609,6 @@ def results(self, save_path: str = None):
607609

608610
self.write_results(class_mIoU)
609611

610-
overall_mse = None
611612
if self._mse_results:
612613
mse_results_dict = {"per_frame_mse": self._mse_results}
613614
overall_mse = 0.0
@@ -636,7 +637,9 @@ def results(self, save_path: str = None):
636637
) as f:
637638
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
638639

639-
return class_mIoU, overall_mse
640+
return class_mIoU, overall_mse
641+
else:
642+
return class_mIoU
640643

641644

642645
@register_evaluator("MOT-JDE-EVAL")
@@ -788,7 +791,6 @@ def results(self, save_path: str = None):
788791

789792
self.write_results(out)
790793

791-
overall_mse = None
792794
if self._mse_results:
793795
mse_results_dict = {"per_frame_mse": self._mse_results}
794796
overall_mse = 0.0
@@ -817,7 +819,9 @@ def results(self, save_path: str = None):
817819
) as f:
818820
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
819821

820-
return out, overall_mse
822+
return out, overall_mse
823+
else:
824+
return out
821825

822826
@staticmethod
823827
def digest_summary(summary):
@@ -1126,7 +1130,10 @@ def results(self, save_path: str = None):
11261130

11271131
self.write_results(eval_results)
11281132

1129-
overall_mse = None
1133+
*listed_items, summary = eval_results
1134+
1135+
self._logger.info("\n" + summary)
1136+
11301137
if self._mse_results:
11311138
mse_results_dict = {"per_frame_mse": self._mse_results}
11321139
overall_mse = 0.0
@@ -1155,11 +1162,12 @@ def results(self, save_path: str = None):
11551162
) as f:
11561163
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
11571164

1158-
*listed_items, summary = eval_results
1159-
1160-
self._logger.info("\n" + summary)
1161-
1162-
return {"AP": listed_items[0] * 100, "AP50": listed_items[1] * 100}, overall_mse
1165+
return {
1166+
"AP": listed_items[0] * 100,
1167+
"AP50": listed_items[1] * 100,
1168+
}, overall_mse
1169+
else:
1170+
return {"AP": listed_items[0] * 100, "AP50": listed_items[1] * 100}
11631171

11641172
def _convert_to_coco_format(self, outputs, info_imgs, ids):
11651173
# reference : yolox > evaluators > coco_evaluator > convert_to_coco_format
@@ -1363,7 +1371,11 @@ def results(self, save_path: str = None):
13631371

13641372
self.write_results(eval_results)
13651373

1366-
overall_mse = None
1374+
# item_keys = list(eval_results.keys())
1375+
item_vals = list(eval_results.values())
1376+
1377+
# self._logger.info("\n" + summary)
1378+
13671379
if self._mse_results:
13681380
mse_results_dict = {"per_frame_mse": self._mse_results}
13691381
overall_mse = 0.0
@@ -1392,12 +1404,9 @@ def results(self, save_path: str = None):
13921404
) as f:
13931405
json.dump(mse_results_dict, f, ensure_ascii=False, indent=4)
13941406

1395-
# item_keys = list(eval_results.keys())
1396-
item_vals = list(eval_results.values())
1397-
1398-
# self._logger.info("\n" + summary)
1399-
1400-
return {"AP": item_vals[0] * 100, "AP50": item_vals[1] * 100}, overall_mse
1407+
return {"AP": item_vals[0] * 100, "AP50": item_vals[1] * 100}, overall_mse
1408+
else:
1409+
return {"AP": item_vals[0] * 100, "AP50": item_vals[1] * 100}
14011410

14021411

14031412
@register_evaluator("VISUAL-QUALITY-EVAL")

compressai_vision/run/eval_split_inference.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ def main(conf: DictConfig):
267267
"avg_bpp": avg_bpp,
268268
"end_accuracy": performance,
269269
**elap_times,
270-
"inv_mse": None if mse is None else 1.0 / mse,
270+
**(
271+
{"inv_mse": 0 if mse == 0 else 1.0 / mse} if mse is not None else {}
272+
),
271273
}
272274
)
273275
print(tabulate(result_df, headers="keys", tablefmt="psql"))
@@ -283,7 +285,9 @@ def main(conf: DictConfig):
283285
"bitrate (kbps)": bitrate,
284286
"end_accuracy": performance,
285287
**elap_times,
286-
"inv_mse": None if mse is None else 1.0 / mse,
288+
**(
289+
{"inv_mse": 0 if mse == 0 else 1.0 / mse} if mse is not None else {}
290+
),
287291
}
288292
)
289293
print(tabulate(result_df, headers="keys", tablefmt="psql"))

0 commit comments

Comments
 (0)