Skip to content

Commit b684e1a

Browse files
committed
Smaller Changes on Task Manager + Contextmanager for SITK IO
1 parent e1f4ec2 commit b684e1a

File tree

7 files changed

+66
-33
lines changed

7 files changed

+66
-33
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ stats/
55
dataset_cfg/
66
tests/temp/
77
nodes.md
8+
playground/
89

910
# Byte-compiled / optimized / DLL files
1011
__pycache__/

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187
same "printed page" as the copyright notice for easier
188188
identification within third-party archives.
189189

190-
Copyright {yyyy} {name of copyright owner}
190+
Copyright [2025] Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
191191

192192
Licensed under the Apache License, Version 2.0 (the "License");
193193
you may not use this file except in compliance with the License.

src/vidata/file_manager/file_manager.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
recursive: bool = False,
4343
lazy_init: bool = False,
4444
):
45-
self.path = path
45+
self.path = Path(path)
4646
self.file_type = file_type
4747
self.pattern = pattern
4848
self.include_names = include_names
@@ -177,7 +177,7 @@ def get_name(self, file: str | int, with_file_type=True) -> str:
177177
"""Legacy alias for :meth:`name_from_path` (kept for backward compatibility)."""
178178
return self.name_from_path(file, with_file_type)
179179

180-
def name_from_path(self, file: str | int, include_ext: bool = True) -> str:
180+
def name_from_path(self, file: Path | str | int, include_ext: bool = True) -> str:
181181
"""
182182
Get the relative name of a file (e.g., 'subdir/sample.png').
183183
@@ -194,14 +194,16 @@ def name_from_path(self, file: str | int, include_ext: bool = True) -> str:
194194
Relative file name.
195195
"""
196196
if isinstance(file, int):
197-
file = str(self.files[file])
198-
name = (
199-
str(Path(file).relative_to(self.path))
200-
if not str(self.path).endswith(".json")
201-
else str(file)
202-
)
203-
if not include_ext:
204-
name = name.replace(self.file_type, "")
197+
file = self.files[file]
198+
if not isinstance(file, Path):
199+
file = Path(file)
200+
201+
name_pl = file.relative_to(self.path) if self.path.suffix != ".json" else file
202+
name = name_pl.as_posix()
203+
204+
if not include_ext and name.endswith(self.file_type):
205+
name = name[: -len(self.file_type)]
206+
205207
return name
206208

207209
def path_from_name(self, name: str | Path, include_ext=True):
@@ -211,7 +213,7 @@ def path_from_name(self, name: str | Path, include_ext=True):
211213
rel = Path(name)
212214
if include_ext and rel.suffix != self.file_type:
213215
rel = rel.with_suffix(self.file_type)
214-
if str(self.path).endswith(".json"):
216+
if self.path.suffix == ".json":
215217
return rel
216218
else:
217219
return (self.path / rel).resolve()

src/vidata/io/sitk_io.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import locale
2+
from contextlib import contextmanager
13
from pathlib import Path
24

35
import numpy as np
@@ -7,6 +9,20 @@
79
from vidata.utils.affine import build_affine
810

911

12+
@contextmanager
13+
def temporary_c_locale():
14+
# Save current LC_NUMERIC
15+
old_locale = locale.setlocale(locale.LC_NUMERIC, None)
16+
17+
try:
18+
# Switch to safe C locale
19+
locale.setlocale(locale.LC_NUMERIC, "C")
20+
yield
21+
finally:
22+
# Restore original locale
23+
locale.setlocale(locale.LC_NUMERIC, old_locale)
24+
25+
1026
@register_writer("image", ".nii.gz", ".nii", ".mha", ".nrrd", backend="sitk")
1127
@register_writer("mask", ".nii.gz", ".nii", ".mha", ".nrrd", backend="sitk")
1228
def save_sitk(data: np.ndarray, file: str | Path, metadata: dict | None = None) -> list[str]:
@@ -56,7 +72,9 @@ def load_sitk(file: str | Path) -> tuple[np.ndarray, dict]:
5672
- "direction": orientation matrix (np.ndarray)
5773
- "affine": computed affine matrix (np.ndarray)
5874
"""
59-
image = sitk.ReadImage(file)
75+
with temporary_c_locale():
76+
image = sitk.ReadImage(file)
77+
6078
array = sitk.GetArrayFromImage(image)
6179
ndims = len(array.shape)
6280

src/vidata/task_manager/multilabel_segmentation_manager.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@ def empty(size: tuple[int, ...], num_classes: int) -> np.ndarray:
1515
return np.zeros((num_classes, *size), dtype=np.uint8)
1616

1717
@staticmethod
18-
def class_ids(data: np.ndarray) -> np.ndarray:
18+
def class_ids(
19+
data: np.ndarray, return_counts: bool = False
20+
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
1921
"""
2022
Return class indices that are present in the mask (i.e., where at least one pixel is non-zero).
2123
"""
22-
return np.flatnonzero(data.reshape(data.shape[0], -1).any(axis=1))
23-
# return np.where(data.reshape(self.num_classes, -1).any(axis=1))[0]
24+
class_ids = np.flatnonzero(data.reshape(data.shape[0], -1).any(axis=1))
25+
26+
if return_counts:
27+
axes = tuple(range(1, data.ndim))
28+
counts = np.count_nonzero(data, axis=axes)
29+
return class_ids, counts[class_ids]
30+
31+
return class_ids
2432

2533
@staticmethod
2634
def class_count(data: np.ndarray, class_id: int) -> int:
@@ -30,10 +38,14 @@ def class_count(data: np.ndarray, class_id: int) -> int:
3038
return int(np.sum(data[class_id]))
3139

3240
@staticmethod
33-
def class_location(data: np.ndarray, class_id: int) -> tuple[np.ndarray, ...]:
41+
def class_location(
42+
data: np.ndarray, class_id: int, return_mask: bool = False
43+
) -> tuple[np.ndarray, ...] | np.ndarray:
3444
"""
3545
Return indices where the given class is active (non-zero).
3646
"""
47+
if return_mask:
48+
return data[class_id]
3749
return np.where(data[class_id] > 0)
3850

3951
@staticmethod
@@ -45,15 +57,3 @@ def spatial_dims(shape: np.ndarray) -> np.ndarray:
4557
def has_background():
4658
"""if the task has a dedicated background class --> is class 0 the bg class?"""
4759
return False
48-
49-
50-
if __name__ == "__main__":
51-
data = MultiLabelSegmentationManager.random((100, 100), 7)
52-
data = MultiLabelSegmentationManager.empty((100, 100), 7)
53-
data[1, 0, 5] = 1
54-
data[0, 0, 0] = 1
55-
data[5, 0, 0] = 1
56-
print(np.unique(data))
57-
print(data.shape)
58-
print(MultiLabelSegmentationManager.class_ids(data))
59-
print(MultiLabelSegmentationManager.class_location(data, 1))

src/vidata/task_manager/semantic_segmentation_manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,23 @@ def empty(size: tuple[int, ...], num_classes: int) -> np.ndarray:
1515
return np.zeros(size, dtype=np.uint8)
1616

1717
@staticmethod
18-
def class_ids(data: np.ndarray) -> np.ndarray:
18+
def class_ids(
19+
data: np.ndarray, return_counts: bool = False
20+
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
21+
if return_counts:
22+
return np.unique(data, return_counts=True)
1923
return np.unique(data)
2024

2125
@staticmethod
2226
def class_count(data: np.ndarray, class_id: int) -> int:
2327
return int(np.sum(data == class_id))
2428

2529
@staticmethod
26-
def class_location(data: np.ndarray, class_id: int) -> tuple[np.ndarray, ...]:
30+
def class_location(
31+
data: np.ndarray, class_id: int, return_mask: bool = False
32+
) -> tuple[np.ndarray, ...] | np.ndarray:
33+
if return_mask:
34+
return np.asarray(data == class_id) # data == class_id
2735
return np.where(data == class_id)
2836

2937
@staticmethod

src/vidata/task_manager/task_manager.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def empty(size: tuple[int, ...], num_classes: int) -> np.ndarray:
2020

2121
@staticmethod
2222
@abstractmethod
23-
def class_ids(data: np.ndarray) -> np.ndarray:
23+
def class_ids(
24+
data: np.ndarray, return_counts: bool = False
25+
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
2426
"""Return a sorted array of unique class IDs present in the data."""
2527

2628
@staticmethod
@@ -30,7 +32,9 @@ def class_count(data: np.ndarray, class_id: int) -> int:
3032

3133
@staticmethod
3234
@abstractmethod
33-
def class_location(data: np.ndarray, class_id: int) -> tuple[np.ndarray, ...]:
35+
def class_location(
36+
data: np.ndarray, class_id: int, return_mask: bool = False
37+
) -> tuple[np.ndarray, ...] | np.ndarray:
3438
"""Return the indices where the given class ID occurs."""
3539

3640
@staticmethod

0 commit comments

Comments
 (0)