@@ -101,7 +101,7 @@ def export(self):
101101 )
102102
103103
104- def cal_image_matrics (
104+ def cal_metrics (
105105 sheet_name : str = "results" ,
106106 txt_path : str = "" ,
107107 to_append : bool = True ,
@@ -112,7 +112,6 @@ def cal_image_matrics(
112112 metrics_npy_path : str = "./metrics.npy" ,
113113 num_bits : int = 3 ,
114114 num_workers : int = 2 ,
115- ncols_tqdm : int = 79 ,
116115 metric_names : tuple = ("sm" , "wfm" , "mae" , "fmeasure" , "em" ),
117116):
118117 """Save the results of all models on different datasets in a `npy` file in the form of a
@@ -129,7 +128,6 @@ def cal_image_matrics(
129128 metrics_npy_path (str, optional): The npy file path for saving metric values. Defaults to "./metrics.npy".
130129 num_bits (int, optional): The number of bits used to format results. Defaults to 3.
131130 num_workers (int, optional): The number of workers of multiprocessing or multithreading. Defaults to 2.
132- ncols_tqdm (int, optional): Number of columns for tqdm. Defaults to 79.
133131 metric_names (tuple, optional): Names of metrics. Defaults to ("sm", "wfm", "mae", "fmeasure", "em").
134132
135133 Returns:
@@ -167,16 +165,12 @@ def cal_image_matrics(
167165 sheet_name = sheet_name ,
168166 )
169167
170- # multi-process mode
171- # tqdm.set_lock(RLock())
172- # pool_cls = pool.Pool
173- # multi-threading mode
174168 tqdm .set_lock (TRLock ())
175- pool_cls = pool .ThreadPool
176- procs = pool_cls (processes = num_workers , initializer = tqdm .set_lock , initargs = (tqdm .get_lock (),))
169+ procs = pool .ThreadPool (
170+ processes = num_workers , initializer = tqdm .set_lock , initargs = (tqdm .get_lock (),)
171+ )
177172 print (f"Create a { procs } )." )
178173
179- procs_idx = 0
180174 for dataset_name , dataset_path in datasets_info .items ():
181175 # 获取真值图片信息
182176 gt_info = dataset_path ["mask" ]
@@ -187,16 +181,13 @@ def cal_image_matrics(
187181 gt_index_file = dataset_path .get ("index_file" )
188182 if gt_index_file :
189183 gt_name_list = get_name_list (
190- data_path = gt_index_file ,
191- name_prefix = gt_prefix ,
192- name_suffix = gt_suffix ,
184+ data_path = gt_index_file , name_prefix = gt_prefix , name_suffix = gt_suffix
193185 )
194186 else :
195187 gt_name_list = get_name_list (
196- data_path = gt_root ,
197- name_prefix = gt_prefix ,
198- name_suffix = gt_suffix ,
188+ data_path = gt_root , name_prefix = gt_prefix , name_suffix = gt_suffix
199189 )
190+ gt_info_pair = (gt_root , gt_prefix , gt_suffix )
200191 assert len (gt_name_list ) > 0 , "there is not ground truth."
201192
202193 # ==>> test the intersection between pre and gt for each method <<==
@@ -214,33 +205,28 @@ def cal_image_matrics(
214205 pre_name_list = get_name_list (
215206 data_path = pre_root , name_prefix = pre_prefix , name_suffix = pre_suffix
216207 )
208+ pre_info_pair = (pre_root , pre_prefix , pre_suffix )
217209
218210 # get the intersection
219- eval_name_list = sorted (list ( set (gt_name_list ).intersection (pre_name_list ) ))
211+ eval_name_list = sorted (set (gt_name_list ).intersection (pre_name_list ))
220212 if len (eval_name_list ) == 0 :
221213 tqdm .write (f"{ method_name } does not have results on { dataset_name } " )
222214 continue
223215
216+ desc = f"[{ dataset_name } ({ len (gt_name_list )} ):{ method_name } ({ len (pre_name_list )} )]"
224217 kwargs = dict (
225218 names = eval_name_list ,
226219 num_bits = num_bits ,
227- pre_root = pre_root ,
228- pre_prefix = pre_prefix ,
229- pre_suffix = pre_suffix ,
230- gt_root = gt_root ,
231- gt_prefix = gt_prefix ,
232- gt_suffix = gt_suffix ,
233- desc = f"[{ dataset_name } ({ len (gt_name_list )} ):{ method_name } ({ len (pre_name_list )} )]" ,
234- proc_idx = procs_idx ,
220+ pre_info_pair = pre_info_pair ,
221+ gt_info_pair = gt_info_pair ,
235222 metric_names = metric_names ,
236- ncols_tqdm = ncols_tqdm ,
237223 metric_class = metric_class ,
224+ desc = desc ,
238225 )
239226 callback = partial (recorder .record , dataset_name = dataset_name , method_name = method_name )
240227 procs .apply_async (func = evaluate , kwds = kwargs , callback = callback )
241- # for debugging
228+ # print(" -------------------- [DEBUG] -------------------- ")
242229 # callback(evaluate(**kwargs), dataset_name=dataset_name, method_name=method_name)
243- procs_idx += 1
244230 procs .close ()
245231 procs .join ()
246232
@@ -257,38 +243,22 @@ def cal_image_matrics(
257243 tqdm .write (f"All methods have been evaluated:\n { formatted_string } " )
258244
259245
260- def evaluate (
261- names ,
262- num_bits ,
263- gt_root ,
264- gt_prefix ,
265- gt_suffix ,
266- pre_root ,
267- pre_prefix ,
268- pre_suffix ,
269- metric_class ,
270- desc = "" ,
271- proc_idx = None ,
272- metric_names = None ,
273- ncols_tqdm = 79 ,
274- ):
246+ def evaluate (names , num_bits , pre_info_pair , gt_info_pair , metric_class , metric_names , desc = "" ):
275247 metric_recoder = metric_class (metric_names = metric_names )
276248 # https://github.com/tqdm/tqdm#parameters
277249 # https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py
278- for name in tqdm (
279- names , total = len (names ), desc = desc , position = proc_idx , ncols = ncols_tqdm , lock_args = (False ,)
280- ):
250+ for name in tqdm (names , total = len (names ), desc = desc , ncols = 79 , lock_args = (False ,)):
281251 gt , pre = get_gt_pre_with_name (
282252 img_name = name ,
283- pre_root = pre_root ,
284- pre_prefix = pre_prefix ,
285- pre_suffix = pre_suffix ,
286- gt_root = gt_root ,
287- gt_prefix = gt_prefix ,
288- gt_suffix = gt_suffix ,
253+ pre_root = pre_info_pair [ 0 ] ,
254+ pre_prefix = pre_info_pair [ 1 ] ,
255+ pre_suffix = pre_info_pair [ 2 ] ,
256+ gt_root = gt_info_pair [ 0 ] ,
257+ gt_prefix = gt_info_pair [ 1 ] ,
258+ gt_suffix = gt_info_pair [ 2 ] ,
289259 to_normalize = False ,
290260 )
291- metric_recoder .step (pre = pre , gt = gt , gt_path = os .path .join (gt_root , name ))
261+ metric_recoder .step (pre = pre , gt = gt , gt_path = os .path .join (gt_info_pair [ 0 ] , name ))
292262
293263 method_results = metric_recoder .show (num_bits = num_bits , return_ndarray = False )
294264 return method_results
0 commit comments