Skip to content

Commit 1d0e595

Browse files
committed
Add: get_writer to LayerConfigManager
1 parent d416e5f commit 1d0e595

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

src/vidata/config_manager.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
SemanticSegmentationManager,
1818
TaskManager,
1919
)
20+
from vidata.writers import (
21+
BaseWriter,
22+
ImageStackWriter,
23+
ImageWriter,
24+
MultilabelStackedWriter,
25+
MultilabelWriter,
26+
SemSegWriter,
27+
)
2028

2129
_VALID_SPLITS = ["train", "val", "test"]
2230
_IMAGE_LAYERS = {"image"}
@@ -30,6 +38,15 @@
3038
"image": ImageStackLoader,
3139
"multilabel": MultilabelStackedLoader,
3240
}
41+
_WRITER_MAPPING: dict[str, type[BaseWriter]] = {
42+
"image": ImageWriter,
43+
"semseg": SemSegWriter,
44+
"multilabel": MultilabelWriter,
45+
}
46+
_STACKED_WRITER_MAPPING: dict[str, type[BaseWriter]] = {
47+
"image": ImageStackWriter,
48+
"multilabel": MultilabelStackedWriter,
49+
}
3350

3451

3552
class LayerConfigManager:
@@ -274,6 +291,25 @@ def data_loader(self) -> BaseLoader:
274291

275292
return loader_cls(**args)
276293

294+
def data_writer(self) -> BaseWriter:
295+
reg = _STACKED_WRITER_MAPPING if self.file_stack else _WRITER_MAPPING
296+
try:
297+
writer_cls = reg[self.type.lower()]
298+
except KeyError as err:
299+
raise ValueError(f"type {self.type} is not supported for layer {self.name}") from err
300+
301+
args = {
302+
"ftype": self.file_type,
303+
"backend": self.backend,
304+
}
305+
306+
if self.type.lower() in _IMAGE_LAYERS:
307+
args["channels"] = self.channels
308+
elif self.type.lower() in _LABEL_LAYERS:
309+
args["num_classes"] = self.classes
310+
311+
return writer_cls(**args)
312+
277313
def task_manager(self) -> TaskManager:
278314
if self.type.lower() == "semseg":
279315
return SemanticSegmentationManager()

tests/config/test_config_manager.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from vidata.file_manager import FileManager
77
from vidata.io import save_json
88
from vidata.loaders import ImageLoader, MultilabelStackedLoader, SemSegLoader
9+
from vidata.writers import ImageWriter, MultilabelStackedWriter, SemSegWriter
910

1011
LEN_VAL = 4
1112
LEN_TRAIN = 6
@@ -141,13 +142,19 @@ def test_config_manager_splitfile(simple_config):
141142
assert layer.file_stack == l_conf.get("file_stack", False)
142143
assert layer.backend == l_conf["backend"]
143144

144-
for i, (ln, loader) in enumerate(
145-
[("Images", ImageLoader), ("Labels", SemSegLoader), ("MLLabels", MultilabelStackedLoader)]
145+
for i, (ln, loader, writer) in enumerate(
146+
[
147+
("Images", ImageLoader, ImageWriter),
148+
("Labels", SemSegLoader, SemSegWriter),
149+
("MLLabels", MultilabelStackedLoader, MultilabelStackedWriter),
150+
]
146151
):
147152
layer = cm.layer(ln)
148153

149154
dl = layer.data_loader()
150155
assert isinstance(dl, loader)
156+
df = layer.data_writer()
157+
assert isinstance(df, writer)
151158

152159
l_conf = layer.config()
153160
l_conf["pattern"] = simple_config["layers"][i]["pattern"]
@@ -188,13 +195,19 @@ def test_config_manager_splitfile_fold(simple_config):
188195
assert layer.file_stack == l_conf.get("file_stack", False)
189196
assert layer.backend == l_conf["backend"]
190197

191-
for i, (ln, loader) in enumerate(
192-
[("Images", ImageLoader), ("Labels", SemSegLoader), ("MLLabels", MultilabelStackedLoader)]
198+
for i, (ln, loader, writer) in enumerate(
199+
[
200+
("Images", ImageLoader, ImageWriter),
201+
("Labels", SemSegLoader, SemSegWriter),
202+
("MLLabels", MultilabelStackedLoader, MultilabelStackedWriter),
203+
]
193204
):
194205
layer = cm.layer(ln)
195206

196207
dl = layer.data_loader()
197208
assert isinstance(dl, loader)
209+
df = layer.data_writer()
210+
assert isinstance(df, writer)
198211

199212
l_conf = layer.config()
200213
l_conf["pattern"] = simple_config["layers"][i]["pattern"]

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ passenv =
3030
PYVISTA_OFF_SCREEN
3131
extras =
3232
testing
33-
commands = pytest -v --color=yes --cov=data_io --cov-report=xml
33+
commands = pytest -v --color=yes --cov=vidata --cov-report=xml

0 commit comments

Comments
 (0)