|
36 | 36 |
|
37 | 37 | from __future__ import annotations |
38 | 38 |
|
39 | | -import argparse |
40 | | -import csv |
41 | | - |
42 | | -from typing import Any, Dict, List |
| 39 | +from typing import Dict |
43 | 40 |
|
44 | 41 | import motmetrics as mm |
45 | 42 | import torch |
46 | 43 | import utils |
47 | 44 |
|
48 | | -from compressai_vision.evaluators.evaluators import BaseEvaluator, MOT_JDE_Eval |
49 | | - |
50 | | -CLASSES = ["TVD", "HIEVE-1080P", "HIEVE-720P"] |
51 | | - |
52 | | -SEQS_BY_CLASS = { |
53 | | - CLASSES[0]: ["TVD-01", "TVD-02", "TVD-03"], |
54 | | - CLASSES[1]: ["HIEVE-13", "HIEVE-16"], |
55 | | - CLASSES[2]: ["HIEVE-2", "HIEVE-17", "HIEVE-18"], |
56 | | -} |
| 45 | +from compressai_vision.evaluators.evaluators import MOT_JDE_Eval |
57 | 46 |
|
58 | 47 |
|
59 | 48 | def get_accumulator_res_for_tvd(item: Dict): |
@@ -103,87 +92,4 @@ def compute_overall_mota(class_name, items): |
103 | 92 | metrics=mm.metrics.motchallenge_metrics, |
104 | 93 | generate_overall=True, |
105 | 94 | ) |
106 | | - # rendered_summary = mm.io.render_summary( |
107 | | - # summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names |
108 | | - # ) |
109 | | - |
110 | | - # print("\n\n") |
111 | | - # print(rendered_summary) |
112 | | - # print("\n") |
113 | | - |
114 | | - # names.append("Overall") |
115 | 95 | return summary, names |
116 | | - |
117 | | - |
118 | | -if __name__ == "__main__": |
119 | | - parser = argparse.ArgumentParser() |
120 | | - |
121 | | - parser.add_argument( |
122 | | - "-r", |
123 | | - "--result_path", |
124 | | - required=True, |
125 | | - help="For example, '.../logs/runs/[pipeline]/[codec]/[datacatalog]/' ", |
126 | | - ) |
127 | | - parser.add_argument( |
128 | | - "-q", |
129 | | - "--quality_index", |
130 | | - required=False, |
131 | | - default=-1, |
132 | | - type=int, |
133 | | - help="Provide index of quality folders under the `result_path'. quality_index is only meant to point the orderd folders by qp names because there might be different range of qps are used for different sequences", |
134 | | - ) |
135 | | - parser.add_argument( |
136 | | - "-a", |
137 | | - "--all_qualities", |
138 | | - action="store_true", |
139 | | - help="run all 6 rate points in MPEG CTCs", |
140 | | - ) |
141 | | - parser.add_argument( |
142 | | - "-d", |
143 | | - "--dataset_path", |
144 | | - required=True, |
145 | | - help="For example, '.../vcm_testdata/[dataset]' ", |
146 | | - ) |
147 | | - parser.add_argument( |
148 | | - "-c", |
149 | | - "--class_to_compute", |
150 | | - type=str, |
151 | | - choices=CLASSES, |
152 | | - required=True, |
153 | | - ) |
154 | | - |
155 | | - args = parser.parse_args() |
156 | | - if args.all_qualities: |
157 | | - qualities = range(0, 6) |
158 | | - else: |
159 | | - qualities = [args.quality_index] |
160 | | - |
161 | | - with open( |
162 | | - f"{args.result_path}/{args.class_to_compute}.csv", "w", newline="" |
163 | | - ) as file: |
164 | | - writer = csv.writer(file) |
165 | | - for q in qualities: |
166 | | - items = utils.search_items( |
167 | | - args.result_path, |
168 | | - args.dataset_path, |
169 | | - q, |
170 | | - SEQS_BY_CLASS[args.class_to_compute], |
171 | | - BaseEvaluator.get_jde_eval_info_name, |
172 | | - ) |
173 | | - |
174 | | - assert ( |
175 | | - len(items) > 0 |
176 | | - ), "Nothing relevant information found from given directories..." |
177 | | - |
178 | | - summary, names = compute_overall_mota(args.class_to_compute, items) |
179 | | - |
180 | | - motas = [100.0 * sv[13] for sv in summary.values] |
181 | | - |
182 | | - print(f"{'=' * 10} FINAL OVERALL MOTA SUMMARY {'=' * 10}") |
183 | | - print(f"{'-' * 35} : MOTA") |
184 | | - |
185 | | - for key, val in zip(names, motas): |
186 | | - print(f"{str(key):35} : {val:.4f}%") |
187 | | - if key == "Overall": |
188 | | - writer.writerow([str(q), f"{val:.4f}"]) |
189 | | - print("\n") |
0 commit comments