Skip to content

Commit ba84b5f

Browse files
committed
minor improvements: save functions return file names, file manager file filtering more efficient, affine utils
1 parent 1d0e595 commit ba84b5f

File tree

12 files changed

+144
-42
lines changed

12 files changed

+144
-42
lines changed

src/vidata/analysis/image_analyzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, data_loader: BaseLoader, file_manager: FileManager, nchannels
4747
def analyze_case(self, index, verbose=False):
4848
file = self.file_manager[index]
4949
data, meta = self.data_loader.load(file)
50-
50+
data = data[...] # To resolve memmap dtypes
5151
stats = {
5252
"name": file.name,
5353
"dtype": str(data.dtype),

src/vidata/analysis/label_analyzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
def analyze_case(self, index, verbose=False):
3535
file = self.file_manager[index]
3636
data, meta = self.data_loader.load(file)
37-
# data=data.astype(int)
37+
data = data[...] # To resolve memmap dtypes
3838
data = data.astype(np.uint8)
3939

4040
stats = {

src/vidata/analysis/viz_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def adjust_layout(
2121
figure.update_layout(
2222
xaxis={
2323
"title": {"text": xaxis_title, "font": {"size": 18}}, # axis label font
24-
"tickfont": {"size": 14}, # tick labels
24+
# "tickfont": {"size": 14}, # tick labels
2525
}
2626
)
2727
else:

src/vidata/config_manager.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22

3-
from omegaconf import DictConfig
3+
from omegaconf import DictConfig, OmegaConf
44

55
from vidata.file_manager import FileManager, FileManagerStacked
66
from vidata.io import load_json
@@ -264,7 +264,6 @@ def file_manager(self, split: str | None = None, fold: int | None = None) -> Fil
264264
include_names = None
265265
if self.splits_file is not None and split is not None:
266266
include_names = self.resolve_splits_file(split, fold)
267-
268267
return manager_cls(
269268
path=_cfg["path"],
270269
file_type=_cfg["file_type"],
@@ -320,8 +319,11 @@ def task_manager(self) -> TaskManager:
320319

321320

322321
class ConfigManager:
323-
def __init__(self, config: dict | DictConfig):
324-
self.config = config
322+
def __init__(self, config: dict | DictConfig | str):
323+
if isinstance(config, str):
324+
self.config = OmegaConf.load(config)
325+
else:
326+
self.config = config
325327
self.layers = []
326328

327329
split_cfg = self.config.get("splits", {})
@@ -333,6 +335,7 @@ def __init__(self, config: dict | DictConfig):
333335
ovrds = split_cfg[k][layer_cfg["name"]]
334336
layer_split[k] = ovrds if ovrds is not None else {}
335337
lcm = LayerConfigManager(layer_cfg, layer_split, split_cfg.get("splits_file"))
338+
336339
self.layers.append(lcm)
337340

338341
@property
@@ -353,7 +356,6 @@ def __len__(self):
353356

354357

355358
if __name__ == "__main__":
356-
from omegaconf import DictConfig, OmegaConf
357359

358360
path = "../../../dataset_cfg/Cityscapes.yaml"
359361
cfg = dict(OmegaConf.load(path))

src/vidata/file_manager/file_manager.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,24 @@ def __init__(
3838
self.pattern = pattern
3939
self.include_names = include_names
4040
self.exclude_names = exclude_names
41-
4241
self.collect_files()
4342
self.filter_files()
4443

4544
def filter_files(self):
4645
if self.include_names is not None:
46+
_files_re = [str(_file.relative_to(self.path)) for _file in self.files]
4747
self.files = [
4848
_file
49-
for _file in list(self.files)
50-
if any(_token in str(_file.relative_to(self.path)) for _token in self.include_names)
49+
for _file, rel in zip(list(self.files), _files_re, strict=False)
50+
if any(_token in rel for _token in self.include_names)
5151
]
5252

5353
if self.exclude_names is not None:
54+
_files_re = [str(_file.relative_to(self.path)) for _file in self.files]
5455
self.files = [
5556
_file
56-
for _file in list(self.files)
57-
if not any(
58-
_token in str(_file.relative_to(self.path)) for _token in self.exclude_names
59-
)
57+
for _file, rel in zip(list(self.files), _files_re, strict=False)
58+
if not any(_token in rel for _token in self.exclude_names)
6059
]
6160

6261
def collect_files(self):
@@ -73,6 +72,14 @@ def collect_files(self):
7372
files = list(Path(self.path).glob(pattern + self.file_type))
7473
self.files = natsorted(files, key=lambda p: p.name)
7574

75+
def get_name(self, file: str | int, with_file_type=True) -> str:
76+
if isinstance(file, int):
77+
file = str(self.files[file])
78+
name = str(Path(file).relative_to(self.path))
79+
if not with_file_type:
80+
name = name.replace(self.file_type, "")
81+
return name
82+
7683
def __getitem__(self, item: int):
7784
return self.files[item]
7885

src/vidata/io/blosc2_io.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
def save_blosc2(
2121
data: np.ndarray,
2222
file: str,
23-
patch_size: Union[tuple[int, int], tuple[int, int, int]],
23+
patch_size: Union[tuple[int, int], tuple[int, int, int]] | None = None,
2424
clevel: int = 8,
2525
nthreads: int = 8,
2626
codec: blosc2.Codec = blosc2.Codec.ZSTD,
2727
metadata: dict | None = None,
28-
):
28+
) -> list[str]:
2929
"""Saves a NumPy array to a Blosc2 file with specified compression parameters.
3030
3131
Args:
@@ -37,6 +37,18 @@ def save_blosc2(
3737
codec (blosc2.Codec, optional): Compression codec. Defaults to blosc2.Codec.ZSTD.
3838
metadata (Optional[dict], optional): Optional dictionary of metadata to attach. Defaults to None.
3939
"""
40+
41+
if patch_size is None:
42+
_is_float = np.issubdtype(data.dtype.type, np.floating)
43+
_is_2d = data.ndim == 2
44+
45+
# if _is_2d:
46+
base_patch_size = (512 if _is_float else 1024) if _is_2d else (64 if _is_float else 96)
47+
# else:
48+
# base_patch_size = 64 if _is_float else 96
49+
50+
patch_size = tuple([min(s, base_patch_size) for s in data.shape])
51+
4052
blocks, chunks = comp_blosc2_params(data.shape, patch_size, data.itemsize)
4153
blosc2.set_nthreads(nthreads)
4254
blosc2.asarray(
@@ -48,6 +60,7 @@ def save_blosc2(
4860
mmap_mode="w+",
4961
meta=metadata,
5062
)
63+
return [file]
5164

5265

5366
@register_loader("image", ".b2nd", backend="blosc2")
@@ -74,7 +87,7 @@ def load_blosc2(file: str, nthreads: int = 1) -> tuple[blosc2.NDArray, dict]:
7487
def save_blosc2pkl(
7588
data: np.ndarray,
7689
file: str,
77-
patch_size: Union[tuple[int, int], tuple[int, int, int]],
90+
patch_size: Union[tuple[int, int], tuple[int, int, int]] | None = None,
7891
clevel: int = 8,
7992
nthreads: int = 8,
8093
codec: blosc2.Codec = blosc2.Codec.ZSTD,
@@ -92,7 +105,9 @@ def save_blosc2pkl(
92105
metadata (Optional[dict], optional): Optional dictionary of metadata to attach. Defaults to None.
93106
"""
94107
save_blosc2(data, file, patch_size=patch_size, clevel=clevel, nthreads=nthreads, codec=codec)
95-
save_pickle(metadata, str(file).replace(".b2nd", ".pkl"))
108+
file_pkl = str(file).replace(".b2nd", ".pkl")
109+
save_pickle(metadata, file_pkl)
110+
return [file, file_pkl]
96111

97112

98113
@register_loader("image", ".b2nd", backend="blosc2pkl")

src/vidata/io/image_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ def load_image(file: str):
1313

1414
@register_writer("image", ".png", ".jpg", ".jpeg", ".bmp", backend="imageio")
1515
@register_writer("mask", ".png", ".bmp", backend="imageio")
16-
def save_image(data: np.ndarray, file: str):
16+
def save_image(data: np.ndarray, file: str) -> list[str]:
1717
iio.imwrite(file, data)
18+
return [file]
1819

1920

2021
# @register_loader("image", ".png", ".jpg", ".jpeg", ".bmp")

src/vidata/io/nib_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@register_writer("image", ".nii.gz", ".nii", backend="nibabel")
99
@register_writer("mask", ".nii.gz", ".nii", backend="nibabel")
10-
def save_nib(data, file, metadata: dict | None = None) -> None:
10+
def save_nib(data, file, metadata: dict | None = None) -> list[str]:
1111
"""
1212
Save a NumPy array and SITK-style metadata to a NIfTI file using nibabel.
1313
@@ -56,6 +56,7 @@ def save_nib(data, file, metadata: dict | None = None) -> None:
5656

5757
image_nib = nib.Nifti1Image(data, affine=affine_nib)
5858
nib.save(image_nib, str(file))
59+
return [file]
5960

6061

6162
@register_loader("image", ".nii.gz", ".nii", backend="nibabel")

src/vidata/io/numpy_io.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,70 @@
1-
from pathlib import Path
2-
from typing import Union
3-
41
import numpy as np
52

63
from vidata.registry import register_loader, register_writer
74

85

96
@register_loader("image", ".npy", backend="numpy")
107
@register_loader("mask", ".npy", backend="numpy")
11-
def load_npy(path: str) -> np.ndarray:
8+
def load_npy(file: str) -> np.ndarray:
129
"""Load a NumPy array from a .npy file.
1310
1411
Args:
15-
path (str): Path to the .npy file.
12+
file (str): Path to the .npy file.
1613
1714
Returns:
1815
np.ndarray: Loaded NumPy array.
1916
"""
20-
return np.load(path, allow_pickle=False), {}
17+
return np.load(file, allow_pickle=False), {}
2118

2219

2320
@register_writer("image", ".npy", backend="numpy")
2421
@register_writer("mask", ".npy", backend="numpy")
25-
def save_npy(array: np.ndarray, path: Union[str, Path], *args, **kwargs) -> None:
22+
def save_npy(array: np.ndarray, file: str, *args, **kwargs) -> list[str]:
2623
"""Save a NumPy array to a .npy file.
2724
2825
Args:
2926
array (np.ndarray): NumPy array to save.
30-
path (str): Output file path.
27+
file (str): Output file file.
3128
"""
32-
np.save(path, array)
29+
np.save(file, array)
30+
return [file]
3331

3432

3533
@register_loader("image", ".npz", backend="numpy")
3634
@register_loader("mask", ".npz", backend="numpy")
37-
def load_npz(path: str) -> tuple[dict[str, np.ndarray], dict]:
35+
def load_npz(file: str) -> tuple[dict[str, np.ndarray], dict]:
3836
"""Load multiple arrays from a .npz file into a dictionary.
3937
4038
Args:
41-
path (str): Path to the .npz file.
39+
file (str): Path to the .npz file.
4240
4341
Returns:
4442
dict[str, np.ndarray]: dictionary mapping keys to arrays.
4543
"""
46-
with np.load(path) as data:
44+
with np.load(file) as data:
4745
return {key: data[key] for key in data.files}, {}
4846

4947

5048
@register_writer("image", ".npz", backend="numpy")
5149
@register_writer("mask", ".npz", backend="numpy")
5250
def save_npz(
53-
data_dict: dict[str, np.ndarray], path: str, compress: bool = True, *args, **kwargs
54-
) -> None:
51+
data_dict: dict[str, np.ndarray], file: str, compress: bool = True, *args, **kwargs
52+
) -> list[str]:
5553
"""Save multiple NumPy arrays to a .npz file.
5654
5755
Args:
5856
data_dict (dict[str, np.ndarray]): dictionary of arrays to save.
59-
path (str): Output file path.
57+
file (str): Output file file.
6058
compress (bool, optional): Whether to use compressed format. Defaults to True.
6159
"""
6260
if compress:
6361
if isinstance(data_dict, dict):
64-
np.savez_compressed(path, **data_dict)
62+
np.savez_compressed(file, **data_dict)
6563
else:
66-
np.savez_compressed(path, data_dict)
64+
np.savez_compressed(file, data_dict)
6765
else:
6866
if isinstance(data_dict, dict):
69-
np.savez(path, **data_dict)
67+
np.savez(file, **data_dict)
7068
else:
71-
np.savez(path, data_dict)
69+
np.savez(file, data_dict)
70+
return [file]

src/vidata/io/sitk_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@register_writer("image", ".nii.gz", ".nii", ".mha", ".nrrd", backend="sitk")
99
@register_writer("mask", ".nii.gz", ".nii", ".mha", ".nrrd", backend="sitk")
10-
def save_sitk(data, file, metadata: dict | None = None) -> None:
10+
def save_sitk(data: np.ndarray, file: str, metadata: dict | None = None) -> list[str]:
1111
"""Save a NumPy array as a medical image file using SimpleITK.
1212
1313
Args:
@@ -34,6 +34,7 @@ def save_sitk(data, file, metadata: dict | None = None) -> None:
3434
image_sitk.SetDirection(direction.flatten().tolist()[::-1])
3535

3636
sitk.WriteImage(image_sitk, str(file), useCompression=True)
37+
return [file]
3738

3839

3940
@register_loader("image", ".nii.gz", ".nii", ".mha", ".nrrd", backend="sitk")

0 commit comments

Comments
 (0)