Skip to content

Commit 58f2850

Browse files
refactor: use static registry for post-processing task lookup
1 parent ade9a51 commit 58f2850

File tree

4 files changed

+19
-33
lines changed

4 files changed

+19
-33
lines changed

ami/jobs/models.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ami.jobs.tasks import run_job
1919
from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
2020
from ami.ml.models import Pipeline
21+
from ami.ml.post_processing.registry import get_postprocessing_task
2122
from ami.utils.schemas import OrderedEnum
2223

2324
logger = logging.getLogger(__name__)
@@ -651,9 +652,6 @@ class PostProcessingJob(JobType):
651652

652653
@classmethod
653654
def run(cls, job: "Job"):
654-
import ami.ml.post_processing # noqa F401
655-
from ami.ml.post_processing.base import get_postprocessing_task
656-
657655
job.progress.add_stage(cls.name, key=cls.key)
658656
job.update_status(JobState.STARTED)
659657
job.started_at = datetime.datetime.now()
@@ -664,7 +662,7 @@ def run(cls, job: "Job"):
664662
config = params.get("config", {})
665663
job.logger.info(f"Post-processing task: {task_key} with params: {job.params}")
666664

667-
task_cls = get_postprocessing_task(task_key)
665+
task_cls = get_postprocessing_task(key=task_key)
668666
if not task_cls:
669667
raise ValueError(f"Unknown post-processing task '{task_key}'")
670668

ami/ml/post_processing/base.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,13 @@
11
import abc
22
import logging
3-
from typing import Any
3+
import typing
4+
from typing import Any, Optional
45

5-
from ami.jobs.models import Job
66
from ami.ml.models import Algorithm
77
from ami.ml.models.algorithm import AlgorithmTaskType
88

9-
# Registry of available post-processing tasks
10-
POSTPROCESSING_TASKS: dict[str, type["BasePostProcessingTask"]] = {}
11-
12-
13-
def register_postprocessing_task(task_cls: type["BasePostProcessingTask"]):
14-
"""
15-
Decorator to register a post-processing task in the global registry.
16-
Each task must define a unique `key`.
17-
Ensures an Algorithm entry exists for this task.
18-
"""
19-
if not hasattr(task_cls, "key") or not task_cls.key:
20-
raise ValueError(f"Task {task_cls.__name__} missing required 'key' attribute")
21-
22-
# Register the task
23-
POSTPROCESSING_TASKS[task_cls.key] = task_cls
24-
return task_cls
25-
26-
27-
def get_postprocessing_task(name: str) -> type["BasePostProcessingTask"] | None:
28-
"""
29-
Get a task class by its registry key.
30-
Returns None if not found.
31-
"""
32-
return POSTPROCESSING_TASKS.get(name)
9+
if typing.TYPE_CHECKING:
10+
from ami.jobs.models import Job
3311

3412

3513
class BasePostProcessingTask(abc.ABC):
@@ -43,7 +21,7 @@ class BasePostProcessingTask(abc.ABC):
4321

4422
def __init__(
4523
self,
46-
job: Job | None = None,
24+
job: Optional["Job"] = None,
4725
logger: logging.Logger | None = None,
4826
**config: Any,
4927
):

ami/ml/post_processing/registry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Registry of available post-processing tasks
2+
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
3+
4+
POSTPROCESSING_TASKS = {
5+
SmallSizeFilterTask.key: SmallSizeFilterTask,
6+
}
7+
8+
9+
def get_postprocessing_task(key: str):
10+
"""Return a post-processing task class by key."""
11+
return POSTPROCESSING_TASKS.get(key)

ami/ml/post_processing/small_size_filter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
from django.utils import timezone
33

44
from ami.main.models import Detection, SourceImageCollection, Taxon, TaxonRank
5-
from ami.ml.post_processing.base import BasePostProcessingTask, register_postprocessing_task
5+
from ami.ml.post_processing.base import BasePostProcessingTask
66

77

8-
@register_postprocessing_task
98
class SmallSizeFilterTask(BasePostProcessingTask):
109
key = "small_size_filter"
1110
name = "Small size filter"

0 commit comments

Comments
 (0)