Skip to content

Commit 83d0194

Browse files
committed
Task Manager Registry
1 parent 54fbedb commit 83d0194

File tree

6 files changed

+35
-7
lines changed

6 files changed

+35
-7
lines changed

src/vidata/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
# Registry imports
44
from .config_manager import ConfigManager, LayerConfigManager
55
from .file_manager import FileManager, FileManagerStacked
6-
from .registry import LOADER_REGISTRY, WRITER_REGISTRY, register_loader, register_writer
6+
from .registry import (
7+
LOADER_REGISTRY,
8+
TASK_REGISTRY,
9+
WRITER_REGISTRY,
10+
register_loader,
11+
register_task,
12+
register_writer,
13+
)
714

815
__all__ = (
916
"register_loader",
1017
"register_writer",
18+
"register_task",
1119
"LOADER_REGISTRY",
1220
"WRITER_REGISTRY",
21+
"TASK_REGISTRY",
1322
"FileManager",
1423
"FileManagerStacked",
1524
"ConfigManager",

src/vidata/config_manager.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
MultilabelStackedLoader,
1313
SemSegLoader,
1414
)
15+
from vidata.registry import TASK_REGISTRY
1516
from vidata.task_manager import (
16-
MultiLabelSegmentationManager,
17-
SemanticSegmentationManager,
1817
TaskManager,
1918
)
2019
from vidata.writers import (
@@ -313,10 +312,8 @@ def data_writer(self) -> BaseWriter:
313312
return writer_cls(**args)
314313

315314
def task_manager(self) -> TaskManager:
316-
if self.type.lower() == "semseg":
317-
return SemanticSegmentationManager()
318-
elif self.type.lower() == "multilabel":
319-
return MultiLabelSegmentationManager()
315+
if self.type.lower() in TASK_REGISTRY:
316+
return TASK_REGISTRY[self.type.lower()]
320317
else:
321318
raise ValueError(f"No Task manager defined for layer {self.name} and type {self.type}")
322319

src/vidata/registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
lambda: defaultdict(dict)
1414
)
1515

16+
TASK_REGISTRY: dict[str, Any] = {}
17+
1618

1719
def register_loader(target: Target, *dtypes: str, backend: str = "default") -> Callable:
1820
"""
@@ -52,6 +54,19 @@ def decorator(func: Callable) -> Callable:
5254
return decorator
5355

5456

57+
def register_task(name: str):
58+
"""Register a task class under a string identifier."""
59+
60+
def decorator(cls):
61+
# if name in TASK_REGISTRY:
62+
# raise ValueError(f"Task '{name}' already registered.")
63+
TASK_REGISTRY[name] = cls
64+
return cls
65+
66+
return decorator
67+
68+
5569
# --- Trigger backend imports ---
5670
# This must come LAST so the above decorators exist before data_io modules import them
5771
import vidata.io # noqa
72+
import vidata.task_manager # noqa

src/vidata/task_manager/image_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3+
from vidata.registry import register_task
34

5+
6+
@register_task("image")
47
class ImageManager:
58
@staticmethod
69
def random(size: tuple[int, ...], dtype="float") -> np.ndarray:

src/vidata/task_manager/multilabel_segmentation_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
22

3+
from vidata.registry import register_task
34
from vidata.task_manager.task_manager import TaskManager
45

56

7+
@register_task("multilabel")
68
class MultiLabelSegmentationManager(TaskManager):
79
@staticmethod
810
def random(size: tuple[int, ...], num_classes: int) -> np.ndarray:

src/vidata/task_manager/semantic_segmentation_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
22

3+
from vidata.registry import register_task
34
from vidata.task_manager.task_manager import TaskManager
45

56

7+
@register_task("semseg")
68
class SemanticSegmentationManager(TaskManager):
79
@staticmethod
810
def random(size: tuple[int, ...], num_classes: int) -> np.ndarray:

0 commit comments

Comments
 (0)