Skip to content

Commit c803430

Browse files
committed
更新图像和视频的评估代码,以及一些小的修复。
1 parent a8ae171 commit c803430

File tree

8 files changed

+274
-274
lines changed

8 files changed

+274
-274
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ gen
280280
/configs/
281281
# /*.py
282282
/*.sh
283+
/*.ps1
284+
/*.bat
283285
/results/rgb_sod.md
284286
/results/htmls/*.html
285287
!/.github/assets/*.jpg

eval.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import textwrap
55
import warnings
66

7-
from metrics import cal_sod_matrics
7+
from metrics import image_metrics, video_metrics
88
from utils.generate_info import get_datasets_info, get_methods_info
99
from utils.recorders import SUPPORTED_METRICS
1010

@@ -39,7 +39,7 @@ def get_args():
3939
4040
EXAMPLES:
4141
42-
python eval_image.py \
42+
python eval.py \
4343
--dataset-json configs/datasets/rgbd_sod.json \
4444
--method-json \
4545
configs/methods/json/rgbd_other_methods.json \
@@ -113,12 +113,41 @@ def get_args():
113113
"--metric-names",
114114
type=str,
115115
nargs="+",
116-
default=["mae", "fmeasure", "precision", "recall", "em", "sm", "wfm"],
116+
default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"],
117117
choices=SUPPORTED_METRICS,
118118
help="Names of metrics",
119119
)
120+
parser.add_argument(
121+
"--data-type",
122+
type=str,
123+
default="image",
124+
choices=["image", "video"],
125+
help="Type of data.",
126+
)
127+
128+
known_args = parser.parse_known_args()[0]
129+
if known_args.data_type == "video":
130+
parser.add_argument(
131+
"--valid-frame-start",
132+
type=int,
133+
default=0,
134+
help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.",
135+
)
136+
parser.add_argument(
137+
"--valid-frame-end",
138+
type=int,
139+
default=0,
140+
help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.",
141+
)
142+
120143
args = parser.parse_args()
121144

145+
if args.data_type == "video":
146+
args.valid_frame_start = max(args.valid_frame_start, 0)
147+
args.valid_frame_end = min(args.valid_frame_end, 0)
148+
if args.valid_frame_end == 0:
149+
args.valid_frame_end = None
150+
122151
if args.metric_npy:
123152
os.makedirs(os.path.dirname(args.metric_npy), exist_ok=True)
124153
if args.curves_npy:
@@ -149,22 +178,39 @@ def main():
149178
exclude_methods=args.exclude_methods,
150179
)
151180

152-
# 确保多进程在windows上也可以正常使用
153-
cal_sod_matrics.cal_image_matrics(
154-
sheet_name="Results",
155-
to_append=not args.to_overwrite,
156-
txt_path=args.record_txt,
157-
xlsx_path=args.record_xlsx,
158-
methods_info=methods_info,
159-
datasets_info=datasets_info,
160-
curves_npy_path=args.curves_npy,
161-
metrics_npy_path=args.metric_npy,
162-
num_bits=args.num_bits,
163-
num_workers=args.num_workers,
164-
metric_names=args.metric_names,
165-
ncols_tqdm=119,
166-
)
167-
168-
181+
if args.data_type == "image":
182+
image_metrics.cal_metrics(
183+
sheet_name="Results",
184+
to_append=not args.to_overwrite,
185+
txt_path=args.record_txt,
186+
xlsx_path=args.record_xlsx,
187+
methods_info=methods_info,
188+
datasets_info=datasets_info,
189+
curves_npy_path=args.curves_npy,
190+
metrics_npy_path=args.metric_npy,
191+
num_bits=args.num_bits,
192+
num_workers=args.num_workers,
193+
metric_names=args.metric_names,
194+
)
195+
else:
196+
video_metrics.cal_metrics(
197+
sheet_name="Results",
198+
to_append=not args.to_overwrite,
199+
txt_path=args.record_txt,
200+
xlsx_path=args.record_xlsx,
201+
methods_info=methods_info,
202+
datasets_info=datasets_info,
203+
curves_npy_path=args.curves_npy,
204+
metrics_npy_path=args.metric_npy,
205+
num_bits=args.num_bits,
206+
num_workers=args.num_workers,
207+
metric_names=args.metric_names,
208+
return_group=False,
209+
start_idx=args.valid_frame_start,
210+
end_idx=args.valid_frame_end,
211+
)
212+
213+
214+
# 确保多进程在windows上也可以正常使用
169215
if __name__ == "__main__":
170216
main()

metrics/cal_sod_matrics.py renamed to metrics/image_metrics.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def export(self):
101101
)
102102

103103

104-
def cal_image_matrics(
104+
def cal_metrics(
105105
sheet_name: str = "results",
106106
txt_path: str = "",
107107
to_append: bool = True,
@@ -112,7 +112,6 @@ def cal_image_matrics(
112112
metrics_npy_path: str = "./metrics.npy",
113113
num_bits: int = 3,
114114
num_workers: int = 2,
115-
ncols_tqdm: int = 79,
116115
metric_names: tuple = ("sm", "wfm", "mae", "fmeasure", "em"),
117116
):
118117
"""Save the results of all models on different datasets in a `npy` file in the form of a
@@ -129,7 +128,6 @@ def cal_image_matrics(
129128
metrics_npy_path (str, optional): The npy file path for saving metric values. Defaults to "./metrics.npy".
130129
num_bits (int, optional): The number of bits used to format results. Defaults to 3.
131130
num_workers (int, optional): The number of workers of multiprocessing or multithreading. Defaults to 2.
132-
ncols_tqdm (int, optional): Number of columns for tqdm. Defaults to 79.
133131
metric_names (tuple, optional): Names of metrics. Defaults to ("sm", "wfm", "mae", "fmeasure", "em").
134132
135133
Returns:
@@ -167,16 +165,12 @@ def cal_image_matrics(
167165
sheet_name=sheet_name,
168166
)
169167

170-
# multi-process mode
171-
# tqdm.set_lock(RLock())
172-
# pool_cls = pool.Pool
173-
# multi-threading mode
174168
tqdm.set_lock(TRLock())
175-
pool_cls = pool.ThreadPool
176-
procs = pool_cls(processes=num_workers, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),))
169+
procs = pool.ThreadPool(
170+
processes=num_workers, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)
171+
)
177172
print(f"Create a {procs}).")
178173

179-
procs_idx = 0
180174
for dataset_name, dataset_path in datasets_info.items():
181175
# 获取真值图片信息
182176
gt_info = dataset_path["mask"]
@@ -187,16 +181,13 @@ def cal_image_matrics(
187181
gt_index_file = dataset_path.get("index_file")
188182
if gt_index_file:
189183
gt_name_list = get_name_list(
190-
data_path=gt_index_file,
191-
name_prefix=gt_prefix,
192-
name_suffix=gt_suffix,
184+
data_path=gt_index_file, name_prefix=gt_prefix, name_suffix=gt_suffix
193185
)
194186
else:
195187
gt_name_list = get_name_list(
196-
data_path=gt_root,
197-
name_prefix=gt_prefix,
198-
name_suffix=gt_suffix,
188+
data_path=gt_root, name_prefix=gt_prefix, name_suffix=gt_suffix
199189
)
190+
gt_info_pair = (gt_root, gt_prefix, gt_suffix)
200191
assert len(gt_name_list) > 0, "there is not ground truth."
201192

202193
# ==>> test the intersection between pre and gt for each method <<==
@@ -214,33 +205,28 @@ def cal_image_matrics(
214205
pre_name_list = get_name_list(
215206
data_path=pre_root, name_prefix=pre_prefix, name_suffix=pre_suffix
216207
)
208+
pre_info_pair = (pre_root, pre_prefix, pre_suffix)
217209

218210
# get the intersection
219-
eval_name_list = sorted(list(set(gt_name_list).intersection(pre_name_list)))
211+
eval_name_list = sorted(set(gt_name_list).intersection(pre_name_list))
220212
if len(eval_name_list) == 0:
221213
tqdm.write(f"{method_name} does not have results on {dataset_name}")
222214
continue
223215

216+
desc = f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]"
224217
kwargs = dict(
225218
names=eval_name_list,
226219
num_bits=num_bits,
227-
pre_root=pre_root,
228-
pre_prefix=pre_prefix,
229-
pre_suffix=pre_suffix,
230-
gt_root=gt_root,
231-
gt_prefix=gt_prefix,
232-
gt_suffix=gt_suffix,
233-
desc=f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]",
234-
proc_idx=procs_idx,
220+
pre_info_pair=pre_info_pair,
221+
gt_info_pair=gt_info_pair,
235222
metric_names=metric_names,
236-
ncols_tqdm=ncols_tqdm,
237223
metric_class=metric_class,
224+
desc=desc,
238225
)
239226
callback = partial(recorder.record, dataset_name=dataset_name, method_name=method_name)
240227
procs.apply_async(func=evaluate, kwds=kwargs, callback=callback)
241-
# for debugging
228+
# print(" -------------------- [DEBUG] -------------------- ")
242229
# callback(evaluate(**kwargs), dataset_name=dataset_name, method_name=method_name)
243-
procs_idx += 1
244230
procs.close()
245231
procs.join()
246232

@@ -257,38 +243,22 @@ def cal_image_matrics(
257243
tqdm.write(f"All methods have been evaluated:\n{formatted_string}")
258244

259245

260-
def evaluate(
261-
names,
262-
num_bits,
263-
gt_root,
264-
gt_prefix,
265-
gt_suffix,
266-
pre_root,
267-
pre_prefix,
268-
pre_suffix,
269-
metric_class,
270-
desc="",
271-
proc_idx=None,
272-
metric_names=None,
273-
ncols_tqdm=79,
274-
):
246+
def evaluate(names, num_bits, pre_info_pair, gt_info_pair, metric_class, metric_names, desc=""):
275247
metric_recoder = metric_class(metric_names=metric_names)
276248
# https://github.com/tqdm/tqdm#parameters
277249
# https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py
278-
for name in tqdm(
279-
names, total=len(names), desc=desc, position=proc_idx, ncols=ncols_tqdm, lock_args=(False,)
280-
):
250+
for name in tqdm(names, total=len(names), desc=desc, ncols=79, lock_args=(False,)):
281251
gt, pre = get_gt_pre_with_name(
282252
img_name=name,
283-
pre_root=pre_root,
284-
pre_prefix=pre_prefix,
285-
pre_suffix=pre_suffix,
286-
gt_root=gt_root,
287-
gt_prefix=gt_prefix,
288-
gt_suffix=gt_suffix,
253+
pre_root=pre_info_pair[0],
254+
pre_prefix=pre_info_pair[1],
255+
pre_suffix=pre_info_pair[2],
256+
gt_root=gt_info_pair[0],
257+
gt_prefix=gt_info_pair[1],
258+
gt_suffix=gt_info_pair[2],
289259
to_normalize=False,
290260
)
291-
metric_recoder.step(pre=pre, gt=gt, gt_path=os.path.join(gt_root, name))
261+
metric_recoder.step(pre=pre, gt=gt, gt_path=os.path.join(gt_info_pair[0], name))
292262

293263
method_results = metric_recoder.show(num_bits=num_bits, return_ndarray=False)
294264
return method_results

0 commit comments

Comments
 (0)