Skip to content

Commit 969d319

Browse files
Breaking: fully separate file_keys and file_list configuration path to guarantee proper parsing
1 parent ecf0152 commit 969d319

File tree

4 files changed

+61
-38
lines changed

4 files changed

+61
-38
lines changed

src/spine/bin/cli.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,23 @@
55
import os
66
import pathlib
77
import sys
8+
from typing import List, Optional
89

910
from spine.config import load_config
1011
from spine.config.loader import parse_value, set_nested_value
1112

1213

1314
def main(
14-
config,
15-
source,
16-
source_list,
17-
output,
18-
n,
19-
nskip,
20-
log_dir,
21-
weight_prefix,
22-
weight_path,
23-
config_overrides,
15+
config: str,
16+
source: List[str],
17+
source_list: str,
18+
output: str,
19+
n: int,
20+
nskip: int,
21+
log_dir: str,
22+
weight_prefix: str,
23+
weight_path: str,
24+
config_overrides: List[str],
2425
):
2526
"""Main driver for training/validation/inference/analysis.
2627
@@ -47,7 +48,7 @@ def main(
4748
weight_prefix : str
4849
Path to the directory for storing the training weights
4950
weight_path : str
50-
Path string a weight file or pattern for multiple weight files to load
51+
Path to a weight file or pattern for multiple weight files to load
5152
the model weights
5253
config_overrides : List[str]
5354
List of config overrides in the form "key.path=value"
@@ -76,11 +77,19 @@ def main(
7677
assert "io" in cfg, "Must provide an `io` block in the configuration."
7778

7879
# Override the input/output command-line information into the configuration
79-
if (source is not None and len(source) > 0) or source_list is not None:
80+
if source is not None and len(source) > 0:
8081
if "reader" in cfg["io"]:
81-
cfg["io"]["reader"]["file_keys"] = source or source_list
82+
cfg["io"]["reader"]["file_keys"] = source
8283
elif "loader" in cfg["io"]:
83-
cfg["io"]["loader"]["dataset"]["file_keys"] = source or source_list
84+
cfg["io"]["loader"]["dataset"]["file_keys"] = source
85+
else:
86+
raise KeyError("Must specify `loader` or `reader` in the `io` block.")
87+
88+
if source_list is not None:
89+
if "reader" in cfg["io"]:
90+
cfg["io"]["reader"]["file_list"] = source_list
91+
elif "loader" in cfg["io"]:
92+
cfg["io"]["loader"]["dataset"]["file_list"] = source_list
8493
else:
8594
raise KeyError("Must specify `loader` or `reader` in the `io` block.")
8695

@@ -242,15 +251,15 @@ def cli():
242251
# For actual training/inference, call the main function
243252
main(
244253
config=config_file,
245-
source=args.source or [],
254+
source=args.source,
246255
source_list=args.source_list,
247256
output=args.output,
248257
n=args.iterations,
249258
nskip=args.nskip,
250259
log_dir=args.log_dir,
251260
weight_prefix=args.weight_prefix,
252261
weight_path=args.weight_path,
253-
config_overrides=args.config_overrides or [],
262+
config_overrides=args.config_overrides,
254263
)
255264

256265

src/spine/io/core/read/base.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,39 +106,44 @@ def get_run_event(self, run: int, subrun: int, event: int) -> Dict[str, Any]:
106106

107107
def process_file_paths(
108108
self,
109-
file_keys: List[str],
109+
file_keys: Optional[Union[str, List[str]]] = None,
110+
file_list: Optional[str] = None,
110111
limit_num_files: Optional[int] = None,
111112
max_print_files: int = 10,
112113
) -> None:
113114
"""Process list of files.
114115
115116
Parameters
116117
----------
117-
file_keys : list
118-
List of paths to the HDF5 files to be read
118+
file_keys : Union[str, List[str]], optional
119+
Path or list of paths to the HDF5 files to be read
120+
file_list : str, optional
121+
Path to a text file containing a list of file paths to be read
119122
limit_num_files : int, optional
120123
Integer limiting number of files to be taken per data directory
121124
max_print_files : int, default 10
122125
Maximum number of loaded file names to be printed
123126
"""
124127
# Some basic checks
125-
assert file_keys is not None, "No input `file_keys` provided, abort."
128+
assert (file_keys is not None) ^ (
129+
file_list is not None
130+
), "Must provide either `file_keys` or `file_list` to process files."
126131
assert (
127132
limit_num_files is None or limit_num_files > 0
128133
), "If `limit_num_files` is provided, it must be larger than 0."
129134

130-
# If the file_keys points to a single text file, it must be a text
131-
# file containing a list of file paths. Parse it to a list.
132-
if isinstance(file_keys, str) and os.path.splitext(file_keys)[-1] == ".txt":
135+
# When using a file list in text format, read it and parse to a list
136+
if file_list is not None:
133137
# If the file list is a text file, extract the list of paths
134-
assert os.path.isfile(file_keys), (
135-
"If the `file_keys` are specified as a single string, "
136-
"it must be the path to a text file with a file list."
138+
assert os.path.isfile(file_list), (
139+
"If the `file_list` argument is provided, it must point to a valid "
140+
"path to a text file with a file list."
137141
)
138-
with open(file_keys, "r", encoding="utf-8") as f:
142+
with open(file_list, "r", encoding="utf-8") as f:
139143
file_keys = f.read().splitlines()
140144

141145
# Convert the file keys to a list of file paths with glob
146+
assert file_keys is not None # For the linter's sake
142147
self.file_paths = []
143148
if isinstance(file_keys, str):
144149
file_keys = [file_keys]

src/spine/io/core/read/hdf5.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Contains a reader class dedicated to loading data from HDF5 files."""
22

33
from dataclasses import fields
4-
from typing import Any, Dict, List, Optional
4+
from typing import Any, Dict, List, Optional, Union
55
from warnings import warn
66

77
import h5py
@@ -34,7 +34,8 @@ class HDF5Reader(ReaderBase):
3434

3535
def __init__(
3636
self,
37-
file_keys: List[str],
37+
file_keys: Optional[Union[str, List[str]]] = None,
38+
file_list: Optional[str] = None,
3839
limit_num_files: Optional[int] = None,
3940
max_print_files: int = 10,
4041
n_entry: Optional[int] = None,
@@ -53,8 +54,10 @@ def __init__(
5354
5455
Parameters
5556
----------
56-
file_keys : List[str]
57-
List of paths to the HDF5 files to be read
57+
file_keys : Union[str, List[str]], optional
58+
Path or list of paths to the HDF5 files to be read
59+
file_list : str, optional
60+
Path to a text file containing a list of file paths to be read
5861
limit_num_files : Optional[int], optional
5962
Integer limiting number of files to be taken per data directory
6063
max_print_files : int, default 10
@@ -86,7 +89,7 @@ def __init__(
8689
If `True`, allows missing entries in the entry or event list
8790
"""
8891
# Process the list of files
89-
self.process_file_paths(file_keys, limit_num_files, max_print_files)
92+
self.process_file_paths(file_keys, file_list, limit_num_files, max_print_files)
9093

9194
# If an entry list is requested based on run/subrun/event ID, create map
9295
if run_event_list is not None or skip_run_event_list is not None:

src/spine/io/core/read/larcv.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Contains a reader class dedicated to loading data from LArCV files."""
22

3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, Union
44

55
import numpy as np
66

@@ -35,8 +35,9 @@ class LArCVReader(ReaderBase):
3535

3636
def __init__(
3737
self,
38-
file_keys: List[str],
39-
tree_keys: List[str],
38+
file_keys: Optional[Union[str, List[str]]] = None,
39+
file_list: Optional[str] = None,
40+
tree_keys: Optional[List[str]] = None,
4041
limit_num_files: Optional[int] = None,
4142
max_print_files: int = 10,
4243
n_entry: Optional[int] = None,
@@ -53,8 +54,10 @@ def __init__(
5354
5455
Parameters
5556
----------
56-
file_keys : list
57-
List of paths to the HDF5 files to be read
57+
file_keys : Union[str, List[str]], optional
58+
Path or list of paths to the LArCV files to be read
59+
file_list : str, optional
60+
Path to a text file containing a list of file paths to be read
5861
tree_keys : List[str]
5962
List of data keys to load from the LArCV files
6063
limit_num_files : Optional[int], optional
@@ -82,7 +85,10 @@ def __init__(
8285
If `True`, allows missing entries in the entry or event list
8386
"""
8487
# Process the file_paths
85-
self.process_file_paths(file_keys, limit_num_files, max_print_files)
88+
self.process_file_paths(file_keys, file_list, limit_num_files, max_print_files)
89+
assert (
90+
tree_keys is not None and len(tree_keys) > 0
91+
), "No input `tree_keys` provided, abort."
8692

8793
# If an entry list is requested based on run/subrun/event ID, create map
8894
if run_event_list is not None or skip_run_event_list is not None:

0 commit comments

Comments
 (0)