Skip to content

Commit 8d4fbb9

Browse files
authored
Comparison script: Support multiple name patterns and regex (alisw#1020)
1 parent 1ed6415 commit 8d4fbb9

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

machine_learning_hep/utils/compare_root_files.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121
import argparse
2222
import math
23+
import re
2324
import sys
2425
from enum import Enum
2526
from itertools import permutations
@@ -68,7 +69,11 @@ def msg_fatal(message: str):
6869

6970

7071
def list_recursive(
71-
file: TDirectoryFile, objects: dict | None = None, path_dir: str = "", verbose: bool = False, name_pattern: str = ""
72+
file: TDirectoryFile,
73+
objects: dict | None = None,
74+
path_dir: str = "",
75+
verbose: bool = False,
76+
name_patterns: list[str] | None = None,
7277
) -> dict:
7378
"""Recursively load objects from a ROOT file into a dictionary."""
7479
if objects is None:
@@ -77,13 +82,13 @@ def list_recursive(
7782
name_obj = key.GetName()
7883
name_class = key.GetClassName()
7984
path_obj = f"{path_dir + '/' if path_dir else ''}{name_obj}"
80-
if name_pattern and name_pattern not in path_obj:
85+
if name_patterns and not any(re.search(rf"{pattern}", path_obj) for pattern in name_patterns):
8186
continue
8287
obj = file.Get(name_obj)
8388
if verbose:
8489
print(f"{path_obj}: {name_class}")
8590
if isinstance(obj, TDirectoryFile):
86-
list_recursive(obj, objects, path_obj, verbose, name_pattern)
91+
list_recursive(obj, objects, path_obj, verbose, name_patterns)
8792
else:
8893
objects[path_obj] = obj
8994
return objects
@@ -502,7 +507,9 @@ def main():
502507
parser.add_argument("-d", action="store_true", help="report and plot only different objects")
503508
parser.add_argument("-c", action="store_true", help="plot only common objects")
504509
parser.add_argument("-s", action="store_true", help="skip numeric comparison")
505-
parser.add_argument("-n", type=str, default="", help="name pattern (substring required in the object path)")
510+
parser.add_argument(
511+
"-n", type=str, nargs="+", default=None, help="name patterns (substrings required in the object path)"
512+
)
506513
parser.add_argument(
507514
"-t", type=int, help="tolerance (order of magnitude of the maximum acceptable relative difference of values)"
508515
)
@@ -521,7 +528,7 @@ def main():
521528
diff_only = args.d
522529
common_only = args.c
523530
skip_comparison = args.s
524-
name_pattern = args.n
531+
name_patterns = args.n
525532
mag_epsilon = None if args.t is None else args.t
526533
project = args.proj
527534
n_slices_max = args.slices
@@ -546,7 +553,7 @@ def main():
546553
key_i += f"_{i + 1}"
547554
# Load objects.
548555
print(f"\nLoading objects from file {path_i}.")
549-
objects[key_i] = list_recursive(file_i, verbose=verbose, name_pattern=name_pattern)
556+
objects[key_i] = list_recursive(file_i, verbose=verbose, name_patterns=name_patterns)
550557

551558
# Make projections.
552559
if project:
@@ -562,7 +569,7 @@ def main():
562569
list_names = sorted(set(objects[0].keys()).intersection(objects[1].keys()))
563570
else:
564571
list_names = sorted(set(objects[0].keys()).union(objects[1].keys()))
565-
dict_result = {name: True for name in list_names}
572+
dict_result = dict.fromkeys(list_names, True)
566573
else:
567574
same_structure, common_content, compared_all, same_content, dict_result = are_same_files(
568575
objects, verbose, diff_only, mag_epsilon

0 commit comments

Comments
 (0)