forked from thuanz123/realfill
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmarks.py
More file actions
1944 lines (1731 loc) · 89.4 KB
/
benchmarks.py
File metadata and controls
1944 lines (1731 loc) · 89.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# benchmarks.py
import argparse
import json
import multiprocessing
import multiprocessing.queues
import os
import re
import subprocess
import sys
import time
import traceback
from collections import OrderedDict, defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
from rich.console import Console, Group
# Rich library imports for enhanced terminal UI
from rich.live import Live
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from tqdm import tqdm
# --- Configuration & Constants ---
# Metric name -> (script filename inside the 'benchmark/' directory next to this file, flag).
# NOTE(review): the boolean looks like a "higher score is better" flag (True for
# PSNR/SSIM/DINO/CLIP, False for the LPIPS/DreamSim distances) — confirm at its use site.
# Insertion order is significant: it is the default metric order when --metrics is omitted.
METRICS_CONFIG = OrderedDict(
    PSNR=("psnr_metric.py", True),
    SSIM=("ssim_metric.py", True),
    LPIPS=("lpips_metric.py", False),
    DreamSim=("dreamsim_metric.py", False),
    DINO=("dino_metric.py", True),
    CLIP=("clip_metric.py", True),
)
# LoFTR filter rates; their use lies outside this part of the file.
LOFTR_FILTER_RATES = [0.0, 0.25, 0.50, 0.75]
# Presumably the default for --num_images (result images per scene) — confirm in the CLI setup.
DEFAULT_NUM_IMAGES = 16
# Names joined onto the cache directory.
MASTER_CACHE_FILENAME = "master_results_cache.json"
PER_IMAGE_CACHE_BASE = "per_scene_cache"
LOFTR_RANKING_FILENAME = "loftr_ranking_scores.json"
# --- Helper Functions ---
def find_gt_mask_paths(
results_dir_path: Path, dataset_dirs_map: dict
) -> tuple[str | None, str | None]:
"""Finds Ground Truth (GT) and Mask paths based on a results folder name."""
dir_name = results_dir_path.name
# Regex to match folder names like "RealBench-0-results" or "Custom-123"
match = re.match(r"^(RealBench|Custom)-(\d+)(-results.*)?$", dir_name)
if not match:
return None, None
benchmark_type, scene_number_str = match.group(1), match.group(2)
dataset_base_dir = dataset_dirs_map.get(benchmark_type)
if not dataset_base_dir or not Path(dataset_base_dir).is_dir():
# This print is okay, as it happens before any Rich Live display usually
print(
f"Warning: Dataset base directory for '{benchmark_type}' not found or invalid: {dataset_base_dir}"
)
return None, None
# Construct path based on expected structure: dataset_base_dir / BenchmarkType / SceneNumber / target / gt.png (or mask.png)
base_path = Path(dataset_base_dir) / benchmark_type / scene_number_str / "target"
gt_path = base_path / "gt.png"
mask_path = base_path / "mask.png"
if not gt_path.is_file() or not mask_path.is_file():
# Fallback or detailed logging if paths are not found
# print(f"Debug: Checked GT '{gt_path}' and Mask '{mask_path}' - Not found.")
return None, None
return str(gt_path), str(mask_path)
def count_result_images(folder_path: Path) -> int:
    """Return how many numbered PNG results (``0.png``, ``1.png``, ...) a folder holds.

    A directory entry counts when it is a regular file, its suffix is ``.png``
    (case-insensitive), and its stem consists only of digits. Returns 0 when the
    folder does not exist, or when listing it raises ``OSError`` (a warning is
    printed in that case).
    """
    if not folder_path.is_dir():
        return 0
    try:
        return sum(
            1
            for entry in folder_path.iterdir()
            if entry.is_file()
            and entry.suffix.lower() == ".png"
            and entry.stem.isdigit()
        )
    except OSError as e:
        # Generally called before the Rich Live display starts.
        print(f"Warning: Could not count images in {folder_path}: {e}")
        return 0
def parse_final_score(stdout_str: str) -> float | None:
"""Parses the 'FINAL_SCORE:' line from a metric script's standard output."""
if not isinstance(stdout_str, str):
return None
for line in stdout_str.splitlines():
if line.startswith("FINAL_SCORE:"):
score_part = line.split(":", 1)[1].strip()
if score_part == "ERROR": # Explicit error reported by script
return None
try:
return float(score_part)
except ValueError:
# This print is from within a worker's subprocess call, will appear on worker's stdout.
# Rich will attempt to draw around it.
error_print_func = tqdm.write if sys.stdout.isatty() else print
error_print_func(f"Warning: Could not parse score from FINAL_SCORE line: '{line}'")
return None
return None # FINAL_SCORE line not found
def load_json_cache(file_path: Path | str) -> dict | None:
"""Safely loads JSON data from a file."""
resolved_path = Path(file_path)
if not resolved_path.is_file():
return None
try:
with open(resolved_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError, TypeError) as e:
# This print is from the main process, usually before Rich Live.
print(f"Warning: Cache load failed for {resolved_path}: {e}")
return None
def save_json_cache(data: dict, file_path: Path | str):
    """Write ``data`` to ``file_path`` as indented JSON without corrupting an existing file.

    NumPy values are first converted via ``convert_numpy_types``. The whole
    document is serialized in memory, written to a sibling ``<name>.tmp`` file,
    and then moved over the target with ``os.replace``. As a result, a
    serialization error or an interrupted write leaves any previous cache file
    intact instead of truncated.

    This is a best-effort save: ``OSError``, ``TypeError`` and ``ValueError``
    (e.g. circular references) are reported with ``print`` and not raised.
    """
    resolved_path = Path(file_path)
    tmp_path = resolved_path.with_name(resolved_path.name + ".tmp")
    try:
        resolved_path.parent.mkdir(parents=True, exist_ok=True)
        serializable_data = convert_numpy_types(data)  # Ensure data is JSON serializable
        # Serialize before touching the disk so a failure cannot leave a half-written file.
        payload = json.dumps(serializable_data, indent=4)
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        # Atomic swap: readers see either the old cache or the complete new one.
        os.replace(tmp_path, resolved_path)
    except (OSError, TypeError, ValueError) as e:
        # This print is from the main process.
        print(f"Error saving JSON cache to {resolved_path}: {e}")
        if isinstance(e, (TypeError, ValueError)):
            # Log a snippet of the data that caused the serialization error for debugging
            print(f"Problematic data snippet (first 500 chars): {str(data)[:500]}")
        # Best-effort removal of a leftover temporary file; the original error was already reported.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
def convert_numpy_types(obj):
    """Recursively convert NumPy values into JSON-serializable Python builtins.

    - dicts are rebuilt with converted values (keys are left untouched);
    - lists, tuples and ndarrays become lists of converted elements;
    - NumPy/Python bools become ``bool``; NumPy integers become ``int``;
    - NumPy and Python floats become ``float``, except NaN/+-Inf, which become
      ``None``. JSON has no representation for them, and ``json.dump`` would
      otherwise emit the non-standard ``NaN``/``Infinity`` tokens;
    - ``np.void`` (structured/void scalars) becomes ``None``;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples are emitted as JSON arrays anyway; recursing here also converts
        # NumPy scalars nested inside them (which json.dump would reject).
        return [convert_numpy_types(elem) for elem in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields Python scalars; recurse so non-finite floats are nulled too.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, np.integer):  # Covers every signed/unsigned NumPy integer width
        return int(obj)
    if isinstance(obj, (np.floating, float)):
        # Plain Python floats need the same NaN/Inf handling as NumPy floats.
        return float(obj) if np.isfinite(obj) else None
    if isinstance(obj, np.void):  # Structured arrays / void scalars have no JSON form
        return None
    return obj  # Not a recognized NumPy/float type
def get_scene_key(folder_name: str) -> str | None:
"""Extracts a scene key (e.g., 'RealBench-0') from a folder name."""
match = re.match(r"^(RealBench|Custom)-(\d+)(-results.*)?$", folder_name)
if match:
return f"{match.group(1)}-{match.group(2)}"
return None
def run_metric_script_parallel(
metric_name: str,
script_filename: str,
gt_path: str,
mask_path: str,
results_dir_str: str,
cache_dir_str: str,
num_images: int,
) -> tuple[str, float | None, str]:
"""
Wrapper to run a single metric script as a subprocess.
Returns (metric_name, score, folder_name).
This function is executed within a worker process.
"""
results_path = Path(results_dir_str)
folder_name = results_path.name
# Determine the directory of the metric scripts ('benchmark' subdirectory)
# Assumes this script (benchmarks.py) is in the parent directory of 'benchmark/'
current_script_dir = Path(__file__).parent
benchmark_scripts_dir = current_script_dir / "benchmark"
absolute_script_path = benchmark_scripts_dir / script_filename
# tqdm.write is used for worker messages to minimize interference with Rich in the main process.
# If not in a TTY, it falls back to print.
worker_log_func = tqdm.write if sys.stdout.isatty() else print
if not absolute_script_path.is_file():
worker_log_func(
f"WORKER_ERROR ({metric_name} on {folder_name}): Script not found at {absolute_script_path}"
)
return metric_name, None, folder_name
command = [
sys.executable, # Use the same Python interpreter
str(absolute_script_path),
"--gt_path",
gt_path,
"--mask_path",
mask_path,
"--results_dir",
str(results_path),
"--cache_dir",
cache_dir_str,
"--num_images",
str(num_images),
]
try:
process = subprocess.run(
command,
capture_output=True,
text=True,
check=False,
timeout=720, # 12-minute timeout per metric script
cwd=benchmark_scripts_dir, # Run script from its directory to handle relative paths within it
encoding="utf-8",
errors="replace",
)
stdout_content = process.stdout or ""
stderr_content = process.stderr or ""
if process.returncode != 0:
stderr_suffix = stderr_content.strip()[-500:] # Get last 500 chars of stderr
worker_log_func(
f"\nWORKER_ERROR ({metric_name} on {folder_name}, RC:{process.returncode}): Subprocess failed.\n"
# f" Command: {' '.join(command)}\n" # Potentially too verbose for regular logging
f" Script Dir: {benchmark_scripts_dir}\n"
f" Stderr (last 500 chars): ...{stderr_suffix}\n"
)
return metric_name, None, folder_name
score = parse_final_score(stdout_content)
if (
score is None and "FINAL_SCORE:" in stdout_content and "ERROR" not in stdout_content
): # Check if parsing failed but line was present
stdout_suffix = stdout_content.strip()[-500:]
worker_log_func(
f"\nWORKER_WARN ({metric_name} on {folder_name}, RC:0): Score parsing failed.\n"
f" Stdout (last 500 chars): ...{stdout_suffix}\n"
)
elif score is None and "FINAL_SCORE:" not in stdout_content: # FINAL_SCORE line missing
worker_log_func(
f"\nWORKER_WARN ({metric_name} on {folder_name}, RC:0): FINAL_SCORE line missing in output.\n"
)
return metric_name, score, folder_name
except subprocess.TimeoutExpired:
worker_log_func(
f"\nWORKER_ERROR ({metric_name} on {folder_name}): Timeout expired for metric script.\n"
)
return metric_name, None, folder_name
except Exception as e:
# Catch any other unexpected errors during subprocess execution
worker_log_func(
f"\nWORKER_CRITICAL ({metric_name} on {folder_name}): Exception during script execution: {type(e).__name__} - {e}\n"
)
traceback.print_exc(file=sys.stderr) # Print full traceback to worker's stderr
return metric_name, None, folder_name
def execute_metric_tasks_for_worker(
    metric_name_arg: str,
    script_filename_arg: str,
    tasks_for_this_metric: list,  # List of tuples (gt, mask, results_dir, cache_dir, num_images)
    results_queue: multiprocessing.Queue,
):
    """
    Process entry point that evaluates one metric across many folders.

    For every task tuple ``(gt_path, mask_path, results_dir_str, cache_dir_str,
    num_images)`` it calls ``run_metric_script_parallel`` and pushes two
    messages onto ``results_queue``:

    * ``("RESULT", metric_name, score_or_None, folder_name)``
    * ``("PROGRESS_TICK", metric_name)``

    An unexpected exception during a task is printed to stderr and reported as
    a ``None`` score. As a result, every task yields exactly one RESULT followed
    by one PROGRESS_TICK. No progress bar is drawn here; the parent process
    renders progress from the queue messages.
    """
    for task in tasks_for_this_metric:
        # Derived up-front so a crash below can still be attributed to a folder.
        fallback_folder = Path(task[2]).name
        try:
            call_args = (metric_name_arg, script_filename_arg) + task
            reported_metric, reported_score, reported_folder = run_metric_script_parallel(
                *call_args
            )
            # The score may be None when the metric script failed.
            results_queue.put(("RESULT", reported_metric, reported_score, reported_folder))
        except Exception as exc:
            # Defense-in-depth: run_metric_script_parallel handles its own errors,
            # so reaching this branch means something unexpected went wrong.
            print(
                f"\nWORKER_UNHANDLED_CRASH ({metric_name_arg} on {fallback_folder}): "
                f"Unhandled exception in 'run_metric_script_parallel' call: {type(exc).__name__} - {exc}\n",
                file=sys.stderr,
            )
            traceback.print_exc(file=sys.stderr)
            results_queue.put(("RESULT", metric_name_arg, None, fallback_folder))
        # One tick per attempted task, whatever the outcome; drives the per-metric bar.
        results_queue.put(("PROGRESS_TICK", metric_name_arg))
class BenchmarkRunner:
def __init__(self, cli_args: argparse.Namespace):
    """Resolve paths and options from parsed CLI arguments and reset runtime state.

    Args:
        cli_args: Parsed command-line namespace. Attributes read here:
            ``results_base_dir``, ``realbench_dataset_dir``, ``custom_dataset_dir``,
            ``cache_dir``, ``output_file``, ``num_images``, ``force_recalc``,
            ``metrics`` and ``loftr_script_path``.

    Raises:
        ValueError: If neither dataset directory was given or exists on disk.
    """
    self.args = cli_args
    self.base_results_dir = Path(cli_args.results_base_dir).resolve()
    # Benchmark type -> dataset root, e.g. {"RealBench": Path(...), "Custom": Path(...)}
    self.dataset_dirs_map = {}
    # Validate and store dataset paths from arguments
    if cli_args.realbench_dataset_dir:
        rb_path = Path(cli_args.realbench_dataset_dir).resolve()
        if rb_path.is_dir():
            self.dataset_dirs_map["RealBench"] = rb_path
        else:
            print(f"Warning: RealBench dataset directory not found: {rb_path}")
    if cli_args.custom_dataset_dir:
        cu_path = Path(cli_args.custom_dataset_dir).resolve()
        if cu_path.is_dir():
            self.dataset_dirs_map["Custom"] = cu_path
        else:
            print(f"Warning: Custom dataset directory not found: {cu_path}")
    if not self.dataset_dirs_map:
        raise ValueError(
            "No valid dataset directories were provided or found. Please specify --realbench_dataset_dir or --custom_dataset_dir."
        )

    self.cache_dir = Path(cli_args.cache_dir).resolve()
    self.output_file_path = (
        Path(cli_args.output_file).resolve() if cli_args.output_file else None
    )
    self.num_images_per_scene = cli_args.num_images
    # Lower-cased for case-insensitive matching against metric names.
    self.force_recalc_metrics_list = [
        metric.lower() for metric in (cli_args.force_recalc or [])
    ]
    # Copy rather than alias the requested metrics: run_all_metrics prunes invalid
    # entries from this list, and that must not mutate cli_args.metrics.
    if cli_args.metrics:
        self.metrics_to_run_list = list(cli_args.metrics)
    else:  # Default to all configured metrics if --metrics is not specified
        self.metrics_to_run_list = list(METRICS_CONFIG.keys())

    # Resolve the LoFTR script relative to the directory containing this file.
    if cli_args.loftr_script_path:
        candidate_loftr_path = Path(__file__).parent / cli_args.loftr_script_path
        if candidate_loftr_path.is_file():
            self.loftr_script_path = candidate_loftr_path.resolve()
        else:
            print(
                f"Warning: LoFTR script '{cli_args.loftr_script_path}' (resolved to '{candidate_loftr_path}') "
                "not found. LoFTR analysis will be skipped."
            )
            self.loftr_script_path = None
    else:
        # Defensive: the argument is expected to have a default.
        print("Warning: LoFTR script path not specified. LoFTR analysis will be skipped.")
        self.loftr_script_path = None

    self.master_cache_file = self.cache_dir / MASTER_CACHE_FILENAME
    self.per_image_cache_root_dir = self.cache_dir / PER_IMAGE_CACHE_BASE
    self.discovered_result_folders = []  # Path objects of valid result folders
    self.master_results_data = defaultdict(dict)  # {folder_name: {metric: score}}
    # In-memory cache for per-image scores (two-level nested defaultdict).
    self.per_image_scores_cache = defaultdict(lambda: defaultdict(dict))
    self.loftr_ranking_data = defaultdict(list)  # {folder_name: [ranked_image_filenames]}

    # Initial console output (before any Rich Live display starts).
    print("Benchmark Runner Initialized:")
    print(f" Results Base Directory: {self.base_results_dir}")
    print(f" Dataset Mappings: {self.dataset_dirs_map}")
    print(f" Cache Directory: {self.cache_dir}")
    print(f" Metrics to Evaluate: {', '.join(self.metrics_to_run_list)}")
    print(
        f" Force Recalculate: {', '.join(self.force_recalc_metrics_list) if self.force_recalc_metrics_list else 'None'}"
    )
    if self.loftr_script_path:
        print(f" LoFTR Script: {self.loftr_script_path}")
    if self.output_file_path:
        print(f" Report Output File: {self.output_file_path}")
    # os.cpu_count() returns None (it does not raise) when the count is undeterminable.
    cpu_core_count = os.cpu_count()
    if cpu_core_count is None:
        print(" System Info: Could not detect CPU core count.")
    else:
        print(f" System Info: Detected {cpu_core_count} CPU cores.")
def discover_folders(self):
    """Populate ``self.discovered_result_folders`` with usable result folders.

    Scans the immediate children of ``self.base_results_dir`` in sorted order.
    A directory is kept when:

    - its name matches ``(RealBench|Custom)-<n>[-results...]``;
    - its benchmark type has an entry in ``self.dataset_dirs_map``;
    - ``find_gt_mask_paths`` finds both the GT and the mask file.

    Every skipped folder is logged with its reason, and a summary with skip
    counts is printed at the end. If the base directory is missing, an error
    is printed and the list stays empty.
    """
    print(f"\nScanning for result folders in {self.base_results_dir}...")
    self.discovered_result_folders = []  # Reset if called multiple times
    potential_folders_count = 0
    skipped_due_to_mapping = 0
    skipped_due_to_gt_mask = 0
    if not self.base_results_dir.is_dir():
        print(f"Error: Base results directory not found: {self.base_results_dir}")
        return

    # One pattern both validates the folder name and captures its benchmark type.
    folder_name_pattern = re.compile(r"^(RealBench|Custom)-\d+(-results.*)?$")
    folder_iterator = sorted(self.base_results_dir.iterdir())  # Sort for predictable order
    pbar_disabled = not sys.stdout.isatty()  # Hide the bar when output is not a terminal
    for item_path in tqdm(
        folder_iterator, desc="Scanning Result Folders", unit="folder", disable=pbar_disabled
    ):
        if not item_path.is_dir():
            continue
        name_match = folder_name_pattern.match(item_path.name)
        if name_match is None:
            continue
        potential_folders_count += 1
        benchmark_type_from_folder = name_match.group(1)

        # Skip reasons are always logged, matching run_all_metrics. tqdm.write keeps
        # an active bar intact and still writes when the bar is disabled.
        if benchmark_type_from_folder not in self.dataset_dirs_map:
            tqdm.write(
                f" Skipping '{item_path.name}': Dataset type '{benchmark_type_from_folder}' not mapped in dataset_dirs_map."
            )
            skipped_due_to_mapping += 1
            continue

        gt_file, mask_file = find_gt_mask_paths(item_path, self.dataset_dirs_map)
        if gt_file and mask_file:
            self.discovered_result_folders.append(item_path)
        else:
            tqdm.write(f" Skipping '{item_path.name}': Missing GT/Mask files.")
            skipped_due_to_gt_mask += 1

    # Summary after scan
    print(f"Folder scan complete. Found {potential_folders_count} potential result folders.")
    print(
        f" - Added {len(self.discovered_result_folders)} folders with valid dataset mapping & GT/Mask files."
    )
    if skipped_due_to_mapping > 0:
        print(
            f" - Skipped {skipped_due_to_mapping} folders (dataset type not specified or mapped)."
        )
    if skipped_due_to_gt_mask > 0:
        print(
            f" - Skipped {skipped_due_to_gt_mask} folders (missing corresponding GT/Mask files)."
        )
def load_master_cache(self):
    """Reload ``self.master_results_data`` from the master cache JSON file.

    The in-memory store is always reset to an empty ``defaultdict(dict)``
    first. Only top-level entries whose value is a dict are kept; any other
    entry is reported and skipped. An absent, unreadable, empty or non-object
    cache leaves the store empty.
    """
    print(f"Loading master results cache: {self.master_cache_file}")
    raw_cache = load_json_cache(self.master_cache_file)
    self.master_results_data = defaultdict(dict)  # Start from a clean slate

    if not (raw_cache and isinstance(raw_cache, dict)):
        print("No valid master cache found or cache file is empty/corrupted.")
        return

    accepted = 0
    for folder_name, metrics_dict in raw_cache.items():
        if not isinstance(metrics_dict, dict):
            print(
                f"Warning: Invalid cache entry format for folder '{folder_name}'. Skipping."
            )
            continue
        self.master_results_data[folder_name] = metrics_dict
        accepted += 1
    print(f"Loaded {accepted} folder entries from master cache.")
def save_master_cache(self):
    """Write ``self.master_results_data`` to ``self.master_cache_file`` as JSON."""
    # Hand a plain dict (not the defaultdict) to the serializer.
    snapshot = dict(self.master_results_data)
    print(f"Saving master results cache ({len(snapshot)} entries)...")
    save_json_cache(snapshot, self.master_cache_file)
    # NOTE(review): save_json_cache prints its own errors and returns nothing,
    # so this confirmation is printed even if the write failed.
    print(f"Master cache saved to: {self.master_cache_file}")
def check_folder_contents(self, folder_path: Path) -> bool:
    """Report whether ``folder_path`` holds every expected result image.

    Expects files named ``0.png`` through ``<N-1>.png``, where
    ``N = self.num_images_per_scene``; extra files are ignored. Returns
    ``False`` if the folder is missing, any expected image is absent, or an
    exception occurs while checking (the error is printed).
    """
    if not folder_path.is_dir():
        return False
    try:
        # all() stops at the first missing image, just like an early return would.
        return all(
            (folder_path / f"{index}.png").is_file()
            for index in range(self.num_images_per_scene)
        )
    except Exception as e:
        # Happens during the preparation scan, before the Rich Live display starts.
        print(f"Error while checking contents of folder {folder_path}: {e}")
        return False
def run_all_metrics(self):
"""
Runs metric calculations for all discovered and valid folders.
Uses Rich library for live progress display.
"""
if not self.discovered_result_folders:
print("No result folders discovered to run metrics on. Skipping metric calculation.")
return
# Initialize Rich Console for consistent output
# All prints from this method onwards should ideally use rich_console.print()
# if they are intended to interact correctly with the Live display.
rich_console = Console()
rich_console.print(
"\n--- Preparing Metric Tasks for Parallel Execution ---", style="bold cyan"
)
self.load_master_cache() # This method uses standard print, okay before Live starts
skipped_incomplete_folders = 0
# tasks_by_metric: dict where key is metric_name, value is list of task_detail_tuples
tasks_to_run_by_metric = defaultdict(list)
folders_processed_count = 0 # Folders that are complete and will be processed
folders_requiring_metric_tasks = set() # Folders that have at least one metric to calculate
# Create a mutable copy of metrics to run, as it might be pruned
# if scripts are missing or metrics are invalid.
# Use valid_metrics_for_run to track metrics that actually have scripts.
valid_metrics_for_run = set(self.metrics_to_run_list)
rich_console.print(
"Checking folder contents and identifying required metric calculations..."
)
benchmark_scripts_dir_base = Path(__file__).parent / "benchmark"
# This initial folder check still uses tqdm because it's a straightforward,
# single-threaded preparation step before the complex multiprocessing with Rich Live.
# tqdm.write is used for messages within this loop.
pbar_disabled = not sys.stdout.isatty()
for folder_path in tqdm(
self.discovered_result_folders,
desc="Validating Folders & Cache",
unit="folder",
disable=pbar_disabled,
):
folder_name = folder_path.name
ground_truth_path, mask_path = find_gt_mask_paths(folder_path, self.dataset_dirs_map)
if not ground_truth_path or not mask_path:
# This should ideally not happen if discover_folders worked correctly
tqdm.write(
f"Warning: Skipping '{folder_name}': GT/Mask path became invalid post-discovery."
)
continue
if not self.check_folder_contents(folder_path):
actual_image_count = count_result_images(
folder_path
) # Recount for accurate message
tqdm.write(
f"Skipping '{folder_name}': Incomplete ({actual_image_count}/{self.num_images_per_scene} images)."
)
skipped_incomplete_folders += 1
if folder_name in self.master_results_data:
tqdm.write(
f" - Removing stale cache entry for incomplete folder '{folder_name}'."
)
del self.master_results_data[folder_name]
continue
folders_processed_count += 1
if folder_name not in self.master_results_data:
self.master_results_data[folder_name] = {} # Initialize if new
# Iterate over a copy of valid_metrics_for_run for safe removal during iteration
for metric_name in list(valid_metrics_for_run):
if metric_name not in METRICS_CONFIG:
tqdm.write(
f"Warning: Metric '{metric_name}' is not in METRICS_CONFIG. Removing from current run."
)
valid_metrics_for_run.discard(metric_name)
if metric_name in self.metrics_to_run_list:
self.metrics_to_run_list.remove(metric_name)
continue
script_filename, _ = METRICS_CONFIG[metric_name]
absolute_script_path = benchmark_scripts_dir_base / script_filename
if not absolute_script_path.is_file():
tqdm.write(
f"Warning: Script '{absolute_script_path}' missing for metric '{metric_name}'. "
f"This metric will be skipped for all folders."
)
valid_metrics_for_run.discard(
metric_name
) # Remove from consideration for this run
if metric_name in self.metrics_to_run_list:
self.metrics_to_run_list.remove(metric_name)
continue # Skip this metric for this folder and subsequent ones
# Determine if recalculation is needed
force_recalculation = (
"all" in self.force_recalc_metrics_list
or metric_name.lower() in self.force_recalc_metrics_list
)
score_is_cached_and_valid = (
metric_name in self.master_results_data.get(folder_name, {})
and self.master_results_data[folder_name][metric_name] is not None
)
if force_recalculation or not score_is_cached_and_valid:
task_details_tuple = (
ground_truth_path,
mask_path,
str(folder_path),
str(self.cache_dir),
self.num_images_per_scene,
)
tasks_to_run_by_metric[metric_name].append(task_details_tuple)
folders_requiring_metric_tasks.add(folder_name)
# Pre-initialize or invalidate stale cache entry for this metric and folder
self.master_results_data[folder_name][metric_name] = None
# Filter out metrics that ended up with no tasks after folder/cache checks
# or were removed because their scripts were missing.
active_tasks_by_metric = {
metric: tasks
for metric, tasks in tasks_to_run_by_metric.items()
if tasks and metric in valid_metrics_for_run
}
total_individual_metric_folder_tasks = sum(
len(tasks) for tasks in active_tasks_by_metric.values()
)
if not total_individual_metric_folder_tasks:
rich_console.print(
"\nNo metric calculations required (all results cached, or folders incomplete/skipped, or no valid metrics)."
)
self.save_master_cache() # Save in case stale entries were removed
rich_console.print("\n--- Metric Execution Phase Skipped ---", style="bold yellow")
rich_console.print(
f"Folders eligible for processing (complete): {folders_processed_count}"
)
if skipped_incomplete_folders > 0:
rich_console.print(
f"Folders skipped due to incompleteness: {skipped_incomplete_folders}"
)
rich_console.print(f"Metric-folder tasks needing calculation: 0")
return
num_dedicated_metric_workers = len(active_tasks_by_metric)
rich_console.print(
f"\n--- Launching {total_individual_metric_folder_tasks} Metric-Folder Tasks "
f"using {num_dedicated_metric_workers} Dedicated Metric Worker Processes ---",
style="bold cyan",
)
final_metrics_being_executed = sorted(list(active_tasks_by_metric.keys()))
rich_console.print(
f"Metrics involved in this run: {', '.join(final_metrics_being_executed)}"
)
rich_console.print(
f"Results will be updated/stored for {len(folders_requiring_metric_tasks)} folders."
)
# --- Multiprocessing and Rich Live Display Setup ---
# Manager().Queue() is generally more robust for complex objects or across different OS
results_queue = multiprocessing.Manager().Queue()
worker_processes_list = []
# Define Rich Progress objects
# Overall progress bar for all individual metric-folder tasks
overall_progress_display = Progress(
TextColumn("Overall Progress:"),
BarColumn(bar_width=None),
"[progress.percentage]{task.percentage:>3.1f}%",
TextColumn("({task.completed} of {task.total} tasks)"),
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
SpinnerColumn(spinner_name="dots"),
console=rich_console,
transient=False, # Keep this bar visible after completion
)
overall_task_id = overall_progress_display.add_task(
"Calculating all metrics", total=total_individual_metric_folder_tasks, start=False
)
# Group for individual metric progress bars
metric_specific_progress_display = Progress(
TextColumn("[bold blue]{task.description}", justify="right"),
BarColumn(bar_width=None),
TextColumn("[progress.percentage]{task.percentage:>3.1f}%"),
TextColumn("({task.completed} of {task.total} folders)"),
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
console=rich_console,
transient=True, # Individual metric bars can disappear when done
)
# Map metric names to their Rich Task IDs
metric_task_ids_map = {}
for metric_name_key, tasks_list_for_metric in active_tasks_by_metric.items():
task_id = metric_specific_progress_display.add_task(
description=f"Metric: {metric_name_key:<10}", # Pad for alignment
total=len(tasks_list_for_metric),
start=False, # Will be started when worker process launches
)
metric_task_ids_map[metric_name_key] = task_id
# Group the progress displays for Rich Live
# Any other Rich renderables (like Tables or Panels for logs) could be added here.
live_display_group = Group(metric_specific_progress_display, overall_progress_display)
# Start the Live display context
# redirect_stdout and redirect_stderr are False by default. This means prints from
# worker processes (like those from tqdm.write in run_metric_script_parallel)
# will print directly to terminal, and Rich Live will redraw around them.
# Setting them to True can capture worker output but might have performance implications
# or require careful handling if workers produce a lot of output.
with Live(
live_display_group,
console=rich_console,
refresh_per_second=12,
vertical_overflow="visible",
) as live:
overall_progress_display.start_task(overall_task_id) # Start the overall counter
# Launch worker processes
for metric_name_to_run, tasks_list in active_tasks_by_metric.items():
if metric_name_to_run in metric_task_ids_map: # Ensure task ID exists
metric_specific_progress_display.start_task(
metric_task_ids_map[metric_name_to_run]
)
script_filename_for_metric, _ = METRICS_CONFIG[metric_name_to_run]
worker_args = (
metric_name_to_run,
script_filename_for_metric,
tasks_list,
results_queue,
)
process = multiprocessing.Process(
target=execute_metric_tasks_for_worker, args=worker_args
)
worker_processes_list.append(process)
process.start()
live.console.print("[green]All metric worker processes launched.[/green]")
# Collect results from the queue
num_results_collected = 0
active_workers_exist = True
while (
num_results_collected < total_individual_metric_folder_tasks
and active_workers_exist
):
try:
# Timeout allows the Live display to refresh and check worker status
message_from_worker = results_queue.get(timeout=0.5) # seconds
msg_type, metric_name_from_msg, *payload_data = message_from_worker
if msg_type == "PROGRESS_TICK":
if metric_name_from_msg in metric_task_ids_map:
metric_specific_progress_display.advance(
metric_task_ids_map[metric_name_from_msg], 1
)
# Note: Overall progress is advanced when a 'RESULT' is processed.
elif msg_type == "RESULT":
score_value, folder_name_from_msg = payload_data
if folder_name_from_msg and metric_name_from_msg:
# Ensure folder entry exists (should, from earlier prep)
if folder_name_from_msg not in self.master_results_data:
live.console.print(
f"[yellow]Warning: Result for unexpected folder '{folder_name_from_msg}'. Initializing entry.[/yellow]"
)
self.master_results_data[folder_name_from_msg] = {}
# Store the result if the metric is still considered valid for this run
if metric_name_from_msg in valid_metrics_for_run:
self.master_results_data[folder_name_from_msg][
metric_name_from_msg
] = score_value
num_results_collected += 1
overall_progress_display.advance(overall_task_id, 1)
# Could add handling for other message types, e.g., explicit log messages from workers
# elif msg_type == 'WORKER_LOG':
# log_level, log_message = payload_data
# live.console.print(f"Worker ({metric_name_from_msg}): {log_message}", style=log_level)
# Expected if queue is empty with timeout
except multiprocessing.queues.Empty:
pass
except Exception as e: # Handle unexpected errors during queue processing
live.console.print(
f"[bold red]Error processing message from worker queue: {e}[/bold red]"
)
# Log traceback to stderr to avoid interfering with Rich display too much
traceback.print_exc(file=sys.stderr)
# Check if all worker processes have exited if we haven't collected all results yet
if num_results_collected < total_individual_metric_folder_tasks:
active_workers_exist = any(p.is_alive() for p in worker_processes_list)
if not active_workers_exist:
live.console.print(
"[bold yellow]Warning: All worker processes have exited, "
"but not all expected results were collected. "
"Check for errors in worker logs (printed above Rich UI or in stderr).[/bold yellow]"
)
# Update overall progress to reflect actual collected items if short
if overall_progress_display.tasks[0].completed < num_results_collected:
overall_progress_display.update(
overall_task_id, completed=num_results_collected
)
break # Exit collection loop
# After collection loop (either completed or broken due to workers finishing early)
# Ensure overall progress reflects the final count
overall_progress_display.update(overall_task_id, completed=num_results_collected)
if num_results_collected < total_individual_metric_folder_tasks:
overall_progress_display.update(
overall_task_id, description="Calculating all metrics (Run Incomplete)"
)
else:
overall_progress_display.update(
overall_task_id, description="Calculating all metrics (Completed)"
)
live.console.print(
"[green]Result collection phase complete. Waiting for worker processes to join...[/green]"
)
for i, process_to_join in enumerate(worker_processes_list):
process_to_join.join(timeout=10) # Give a generous timeout for clean exit
if process_to_join.is_alive():
live.console.print(
f"[yellow]Warning: Worker process {process_to_join.pid} (task {i}) did not join cleanly. Terminating.[/yellow]"
)
process_to_join.terminate() # Force terminate if stuck
process_to_join.join() # Wait for termination
# Stop individual metric progress tasks if they are not transient or to ensure they are marked done
for task_id_val in metric_task_ids_map.values():
if not metric_specific_progress_display.tasks[task_id_val].finished:
metric_specific_progress_display.stop_task(task_id_val)
# Stop the entire metric_specific_progress_display if it's not needed anymore
metric_specific_progress_display.stop()
# Overall progress is not transient, so its task will remain. We can stop it.
if not overall_progress_display.tasks[0].finished:
overall_progress_display.stop_task(overall_task_id)
# overall_progress_display itself is not stopped to keep it visible as a summary.
# --- End of Rich Live context ---
# Final messages after Rich Live context has ended
rich_console.print(
"[bold green]All dedicated metric worker processes have completed processing.[/bold green]"
)
self.save_master_cache() # Save all collected results (including None for errors)
rich_console.print("\n--- Metric Execution Summary ---", style="bold cyan")
rich_console.print(f"Folders eligible for processing (complete): {folders_processed_count}")
if skipped_incomplete_folders > 0:
rich_console.print(
f"Folders skipped due to incompleteness: {skipped_incomplete_folders}"
)
rich_console.print(
f"Folders requiring metric calculations this run: {len(folders_requiring_metric_tasks)}"
)
rich_console.print(
f"Total individual metric-folder tasks processed: {num_results_collected} "
f"(out of {total_individual_metric_folder_tasks} identified)."
)
def load_per_image_results(self, folder_name: str, metric_name: str) -> dict | None:
"""
Loads per-image scores for a given folder and metric, used for LoFTR analysis.
Caches results in self.per_image_scores_cache to avoid redundant file I/O.
"""
# Check in-memory cache first
if (
folder_name in self.per_image_scores_cache
and metric_name in self.per_image_scores_cache[folder_name]
):
return self.per_image_scores_cache[folder_name][metric_name]
# Determine subdirectory for per-image cache based on metric type (masked or not)
# This matches the naming convention used in the __main__ block for cache setup.
is_masked_metric = metric_name in ["PSNR", "SSIM", "LPIPS"] # Example masked metrics
cache_subdir_name = metric_name.lower() + ("_masked" if is_masked_metric else "")
per_image_cache_file = (
self.per_image_cache_root_dir / cache_subdir_name / f"{folder_name}.json"
)
data = load_json_cache(per_image_cache_file) # Uses standard print for errors
if (
data
and isinstance(data, dict)
and "per_image" in data
and isinstance(data["per_image"], dict)
):
# Store in in-memory cache for future calls
self.per_image_scores_cache[folder_name][metric_name] = data["per_image"]
return data["per_image"]
# print(f"Debug: Per-image scores not found or invalid for {metric_name} in {folder_name} at {per_image_cache_file}")
return None
def run_loftr_ranking(self):
"""
Runs the LoFTR ranking script for suitable RealBench folders.
This method uses tqdm for its own progress as it's a sequential operation.
"""
if not self.loftr_script_path:
print("\nLoFTR ranking skipped: LoFTR script path is invalid or script not found.")
return
print("\n--- Running LoFTR Ranking ---", flush=True) # flush in case of subsequent tqdm
# Filter folders suitable for LoFTR: RealBench, non-gen, non-fp32, and complete
folders_for_loftr_ranking = [
folder_path
for folder_path in self.discovered_result_folders
if folder_path.name.startswith("RealBench")
and "gen" not in folder_path.name.lower() # case-insensitive check for "gen"
and "fp32" not in folder_path.name.lower() # case-insensitive check for "fp32"
and self.check_folder_contents(folder_path) # Ensure folder is complete
]
if not folders_for_loftr_ranking:
print("No suitable RealBench folders found for LoFTR ranking.")
return
print(f"Found {len(folders_for_loftr_ranking)} suitable folders for LoFTR ranking.")
num_ranked_successfully = 0
num_skipped_existing = 0
num_skipped_missing_ref = 0
num_errors = 0
force_loftr_recalc = (
"all" in self.force_recalc_metrics_list or "loftr" in self.force_recalc_metrics_list
)
print("Executing LoFTR ranking script for each suitable folder (this may take time)...")