Skip to content

Commit c52b10e

Browse files
committed
Add lazy loading to FileManager for better multiprocessing; add cv2 IO
1 parent 3902fac commit c52b10e

File tree

8 files changed

+263
-40
lines changed

8 files changed

+263
-40
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"pyyaml",
3838
"seaborn",
3939
"scikit-learn",
40+
"opencv-python-headless"
4041
]
4142

4243
[project.optional-dependencies]

src/vidata/analysis/image_analyzer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def __init__(self, data_loader: BaseLoader, file_manager: FileManager, nchannels
4444
self.stats = None
4545
self.global_stats = None
4646

47-
def analyze_case(self, index, verbose=False):
48-
file = self.file_manager[index]
47+
def analyze_case(self, file, verbose=False):
48+
# file = self.file_manager[index]
4949
data, meta = self.data_loader.load(file)
5050
data = data[...] # To resolve memmap dtypes
5151
stats = {
@@ -79,7 +79,7 @@ def analyze_case(self, index, verbose=False):
7979
def run(self, n_processes=8, progressbar=True, verbose=False):
8080
stats = multiprocess_iter(
8181
self.analyze_case,
82-
iterables={"index": np.arange(0, len(self.file_manager))},
82+
iterables={"file": self.file_manager},
8383
const={"verbose": verbose},
8484
p=n_processes,
8585
progressbar=progressbar,

src/vidata/analysis/label_analyzer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def __init__(
3131
self.n_classes = n_classes
3232
self.ignore_bg = ignore_bg
3333

34-
def analyze_case(self, index, verbose=False):
35-
file = self.file_manager[index]
34+
def analyze_case(self, file, verbose=False):
35+
# file = self.file_manager[index]
3636
data, meta = self.data_loader.load(file)
3737
data = data[...] # To resolve memmap dtypes
3838
data = data.astype(np.uint8)
@@ -58,7 +58,7 @@ def analyze_case(self, index, verbose=False):
5858
def run(self, n_processes=8, progressbar=True, verbose=False):
5959
stats = multiprocess_iter(
6060
self.analyze_case,
61-
iterables={"index": np.arange(0, len(self.file_manager))},
61+
iterables={"file": self.file_manager},
6262
const={"verbose": verbose},
6363
p=n_processes,
6464
progressbar=progressbar,
@@ -173,13 +173,15 @@ def plot(self, path, name=""):
173173
# --- Size - Frequency Plot --- #
174174
colors = get_colormap("tab10", len(class_cnt), as_uint=True)
175175
fig = go.Figure()
176-
for cnt, size, name, col in zip(class_cnt, class_size, categories, colors, strict=False):
176+
for cnt, size, legend_name, col in zip(
177+
class_cnt, class_size, categories, colors, strict=False
178+
):
177179
fig.add_trace(
178180
go.Scatter(
179181
x=[cnt],
180182
y=[size],
181183
mode="markers",
182-
name=name, # ← legend label
184+
name=legend_name, # ← legend label
183185
marker={
184186
"size": 15,
185187
"color": f"rgb{col}",

src/vidata/file_manager/file_manager.py

Lines changed: 160 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class FileManager:
1111
"""
1212
Flexible file collector with optional patterns and name based filtering.
13+
Also supports lazy loading (useful for multiprocessing).
1314
1415
Parameters
1516
----------
@@ -25,6 +26,8 @@ class FileManager:
2526
Drop files whose RELATIVE path contains ANY of these substrings. (Exclude wins.)
2627
recursive: bool
2728
Whether to recursively search subdirectories.
29+
lazy_init : bool
30+
If True, defer file collection until the first access (default: False).
2831
"""
2932

3033
def __init__(
@@ -35,56 +38,154 @@ def __init__(
3538
include_names: list[str] | None = None,
3639
exclude_names: list[str] | None = None,
3740
recursive: bool = False,
41+
lazy_init: bool = False,
3842
):
3943
self.path = path
4044
self.file_type = file_type
4145
self.pattern = pattern
4246
self.include_names = include_names
4347
self.exclude_names = exclude_names
4448
self.recursive = recursive
45-
self.collect_files()
46-
self.filter_files()
4749

48-
def filter_files(self):
49-
if self.include_names is not None:
50-
_files_re = [str(_file.relative_to(self.path)) for _file in self.files]
51-
self.files = [
50+
self._files: list[Path] | None
51+
if not lazy_init:
52+
self.refresh()
53+
else:
54+
self._files = None
55+
56+
def refresh(self):
57+
"""
58+
(Re)collect and filter files immediately.
59+
60+
This method rebuilds the internal file list by scanning the directory and
61+
applying inclusion/exclusion filters.
62+
"""
63+
self._files = self.collect_files(self.path, self.file_type, self.pattern, self.recursive)
64+
self._files = self.filter_files(
65+
self._files, self.path, self.include_names, self.exclude_names
66+
)
67+
68+
@property
69+
def files(self) -> list[Path]:
70+
"""
71+
Lazily returns the collected file list.
72+
73+
If `lazy_init=True` was set and the files have not yet been collected,
74+
this property will automatically trigger a collection.
75+
"""
76+
if self._files is None: # Lazy loading
77+
self.refresh()
78+
assert self._files is not None
79+
return self._files
80+
81+
@files.setter
82+
def files(self, value: list[Path]):
83+
"""Directly override the internal file list (advanced use only)."""
84+
self._files = value
85+
86+
@staticmethod
87+
def filter_files(
88+
files: list[Path],
89+
path: Path,
90+
include_names: list[str] | None = None,
91+
exclude_names: list[str] | None = None,
92+
) -> list[Path]:
93+
"""
94+
Filter a list of files based on inclusion or exclusion substrings.
95+
96+
Parameters
97+
----------
98+
files : list[Path]
99+
Input file list.
100+
path : Path
101+
Root path used to compute relative paths for filtering.
102+
include_names : list[str] | None
103+
Substrings; keep files containing any of these in their relative path.
104+
exclude_names : list[str] | None
105+
Substrings; remove files containing any of these in their relative path.
106+
107+
Returns
108+
-------
109+
list[Path]
110+
Filtered file list.
111+
"""
112+
if include_names is not None:
113+
_files_re = [str(_file.relative_to(path)) for _file in files]
114+
files = [
52115
_file
53-
for _file, rel in zip(list(self.files), _files_re, strict=False)
54-
if any(_token in rel for _token in self.include_names)
116+
for _file, rel in zip(list(files), _files_re, strict=False)
117+
if any(_token in rel for _token in include_names)
55118
]
56119

57-
if self.exclude_names is not None:
58-
_files_re = [str(_file.relative_to(self.path)) for _file in self.files]
59-
self.files = [
120+
if exclude_names is not None:
121+
_files_re = [str(_file.relative_to(path)) for _file in files]
122+
files = [
60123
_file
61-
for _file, rel in zip(list(self.files), _files_re, strict=False)
62-
if not any(_token in rel for _token in self.exclude_names)
124+
for _file, rel in zip(list(files), _files_re, strict=False)
125+
if not any(_token in rel for _token in exclude_names)
63126
]
127+
return files
64128

65-
def collect_files(self):
66-
if self.file_type == "" or self.path == "":
67-
self.files = []
68-
return
129+
@staticmethod
130+
def collect_files(
131+
path: Path, file_type: str, pattern: str | None, recursive: bool = False
132+
) -> list[Path]:
133+
"""
134+
Collect files under the given directory according to a pattern and extension.
135+
136+
Parameters
137+
----------
138+
path : Path
139+
Root directory to search.
140+
file_type : str
141+
File extension to match (e.g., ".png").
142+
pattern : str | None
143+
Glob-like pattern (e.g., "*_image").
144+
recursive : bool, optional
145+
Whether to recursively search subdirectories.
69146
70-
if self.pattern is None:
147+
Returns
148+
-------
149+
list[Path]
150+
Naturally sorted list of file paths.
151+
"""
152+
if file_type == "" or path == "":
153+
return []
154+
155+
if pattern is None:
71156
pattern = "*"
72-
elif "*" not in self.pattern:
73-
pattern = "*" + self.pattern
157+
elif "*" not in pattern:
158+
pattern = "*" + pattern
74159
else:
75-
pattern = self.pattern
160+
pattern = pattern
76161

77-
if self.recursive:
78-
files = list(Path(self.path).rglob(pattern + self.file_type))
162+
if recursive:
163+
files = list(Path(path).rglob(pattern + file_type))
79164
else:
80-
files = list(Path(self.path).glob(pattern + self.file_type))
81-
self.files = natsorted(files, key=lambda p: p.name)
165+
files = list(Path(path).glob(pattern + file_type))
166+
# self.files = natsorted(files, key=lambda p: p.name)
167+
return natsorted(files, key=lambda p: p.name)
82168

83169
def get_name(self, file: str | int, with_file_type=True) -> str:
84-
"""Just keep this for backwards compatibility"""
170+
"""Legacy alias for :meth:`name_from_path` (kept for backward compatibility)."""
85171
return self.name_from_path(file, with_file_type)
86172

87173
def name_from_path(self, file: str | int, include_ext: bool = True) -> str:
174+
"""
175+
Get the relative name of a file (e.g., 'subdir/sample.png').
176+
177+
Parameters
178+
----------
179+
file : str | int
180+
File path or index into the internal file list.
181+
include_ext : bool
182+
Whether to keep the file extension.
183+
184+
Returns
185+
-------
186+
str
187+
Relative file name.
188+
"""
88189
if isinstance(file, int):
89190
file = str(self.files[file])
90191
name = str(Path(file).relative_to(self.path))
@@ -93,6 +194,9 @@ def name_from_path(self, file: str | int, include_ext: bool = True) -> str:
93194
return name
94195

95196
def path_from_name(self, name: str | Path, include_ext=True):
197+
"""
198+
Convert a relative name (as from :meth:`name_from_path`) to an absolute path.
199+
"""
96200
rel = Path(name)
97201
if include_ext and rel.suffix != self.file_type:
98202
rel = rel.with_suffix(self.file_type)
@@ -107,6 +211,36 @@ def __len__(self):
107211
def __iter__(self):
108212
return iter(self.files)
109213

214+
def __getstate__(self):
215+
"""
216+
Make the object lightweight for pickling.
217+
218+
The file list is omitted to reduce memory footprint when the object is
219+
sent to subprocesses. Workers can rebuild it lazily on first access.
220+
"""
221+
return {
222+
"path": str(self.path),
223+
"file_type": self.file_type,
224+
"pattern": self.pattern,
225+
"include_names": self.include_names,
226+
"exclude_names": self.exclude_names,
227+
"recursive": self.recursive,
228+
"_files": None,
229+
}
230+
231+
def __setstate__(self, state):
232+
"""
233+
Restore object state after unpickling (used in multiprocessing).
234+
The file list will be lazily rebuilt on first access.
235+
"""
236+
self.path = Path(state["path"])
237+
self.file_type = state["file_type"]
238+
self.pattern = state["pattern"]
239+
self.include_names = state["include_names"]
240+
self.exclude_names = state["exclude_names"]
241+
self.recursive = state["recursive"]
242+
self._files = state.get("_files", None)
243+
110244

111245
class FileManagerStacked(FileManager):
112246
"""

src/vidata/io/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# isort: skip_file # order matters, first ones in list are the defaults
22
# ruff: noqa: I001, I002 # disable Ruff's import-sorting checks for this file
3-
from .image_io import load_image, save_image
3+
from .image_io import load_image, save_image, load_imageRGB
4+
from .cv2_io import load_cv2, save_cv2, load_cv2RGB, save_cv2RGB
45
from .sitk_io import load_sitk, save_sitk
56
from .nib_io import load_nib, save_nib, load_nibRO, save_nibRO
67
from .tif_io import load_tif, save_tif
78
from .blosc2_io import load_blosc2, load_blosc2pkl, save_blosc2, save_blosc2pkl
89
from .numpy_io import load_npy, load_npz, save_npy, save_npz
9-
from .json_io import load_json, save_json
10+
from .json_io import load_json, save_json, load_jsongz, save_jsongz
1011
from .pickle_io import load_pickle, save_pickle
1112
from .txt_io import load_txt, save_txt
1213
from .yaml_io import load_yaml, save_yaml
@@ -26,6 +27,11 @@
2627
"save_tif",
2728
"load_image",
2829
"save_image",
30+
"load_imageRGB",
31+
"load_cv2",
32+
"save_cv2",
33+
"load_cv2RGB",
34+
"save_cv2RGB",
2935
"load_npy",
3036
"save_npy",
3137
"load_npz",
@@ -34,6 +40,8 @@
3440
"save_yaml",
3541
"load_json",
3642
"save_json",
43+
"load_jsongz",
44+
"save_jsongz",
3745
"load_pickle",
3846
"save_pickle",
3947
"load_txt",

src/vidata/io/cv2_io.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from pathlib import Path
2+
3+
import cv2
4+
import numpy as np
5+
6+
from vidata.registry import register_loader, register_writer
7+
8+
cv2.setNumThreads(0)
9+
10+
11+
@register_loader("image", ".png", ".jpg", ".jpeg", ".bmp", backend="cv2")
12+
@register_loader("mask", ".png", ".bmp", backend="cv2")
13+
def load_cv2(file: str | Path):
14+
data = cv2.imread(file, cv2.IMREAD_UNCHANGED)
15+
return data, {}
16+
17+
18+
@register_writer("image", ".png", ".jpg", ".jpeg", ".bmp", backend="cv2")
19+
@register_writer("mask", ".png", ".bmp", backend="cv2")
20+
def save_cv2(data: np.ndarray, file: str | Path) -> list[str]:
21+
cv2.imwrite(file, data)
22+
return [str(file)]
23+
24+
25+
@register_writer("image", ".png", ".jpg", ".jpeg", ".bmp", backend="cv2RGB")
26+
def save_cv2RGB(data: np.ndarray, file: str | Path) -> list[str]:
27+
data = cv2.cvtColor(data, cv2.COLOR_RGB2BGR)
28+
cv2.imwrite(file, data)
29+
return [str(file)]
30+
31+
32+
@register_loader("image", ".png", ".jpg", ".jpeg", ".bmp", backend="cv2RGB")
33+
def load_cv2RGB(file: str | Path):
34+
data = cv2.imread(file, cv2.IMREAD_COLOR)
35+
data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) # BGR -> RGB
36+
return data, {}

0 commit comments

Comments
 (0)