Skip to content

Commit 0b77504

Browse files
refactor: refactor job runner to initialize post-processing tasks with job context
1 parent 9012d7f commit 0b77504

File tree

3 files changed

+25
-47
lines changed

3 files changed

+25
-47
lines changed

ami/jobs/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,13 +662,16 @@ def run(cls, job: "Job"):
662662
params = job.params or {}
663663
task_key: str = params.get("task", "")
664664
config = params.get("config", {})
665-
job.logger.info(f"Post-processing task: {task_key} with params: {job.params}")
666-
665+
job.logger.info(f"Post-processing task: {task_key} with params: {config}")
666+
# Get the registered task class
667667
task_cls = get_postprocessing_task(task_key)
668668
if not task_cls:
669669
raise ValueError(f"Unknown post-processing task '{task_key}'")
670670

671+
# Instantiate the task with job context and config
671672
task = task_cls(job=job, **config)
673+
674+
# Run the task
672675
task.run()
673676
job.progress.update_stage(cls.key, status=JobState.SUCCESS, progress=1)
674677
job.finished_at = datetime.datetime.now()

ami/ml/post_processing/base.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import abc
44
import logging
5-
from typing import Any
5+
from typing import Any, Optional
66

77
from ami.jobs.models import Job
88
from ami.ml.models import Algorithm
@@ -39,54 +39,34 @@ class BasePostProcessingTask(abc.ABC):
3939
Abstract base class for all post-processing tasks.
4040
"""
4141

42-
# Each task must override these
4342
key: str = ""
4443
name: str = ""
4544

4645
def __init__(
4746
self,
48-
job: Job | None = None,
49-
logger: logging.Logger | None = None,
47+
job: Optional["Job"] = None,
48+
task_logger: logging.Logger | None = None,
5049
**config: Any,
5150
):
52-
self.job = job
53-
self.config = config
54-
# Choose the right logger
55-
if logger is not None:
56-
self.logger = logger
57-
elif job is not None:
58-
self.logger = job.logger
59-
else:
60-
self.logger = logging.getLogger(f"ami.post_processing.{self.key}")
61-
62-
algorithm, _ = Algorithm.objects.get_or_create(
63-
name=self.__class__.__name__,
64-
defaults={
65-
"description": f"Post-processing task: {self.key}",
66-
"task_type": AlgorithmTaskType.POST_PROCESSING.value,
67-
},
68-
)
69-
self.algorithm: Algorithm = algorithm
70-
71-
self.logger.info(f"Initialized {self.__class__.__name__} with config={self.config}, job={job}")
72-
73-
def update_progress(self, progress: float):
7451
"""
75-
Update progress if job is present, otherwise just log.
52+
Initialize task with optional job and logger context.
7653
"""
54+
self.job = job
55+
self.config: dict[str, Any] = config
7756

78-
if self.job:
79-
self.job.progress.update_stage(self.job.job_type_key, progress=progress)
80-
self.job.save(update_fields=["progress"])
81-
57+
if job:
58+
self.logger = job.logger
59+
elif task_logger:
60+
self.logger = task_logger
8261
else:
83-
# No job object — fallback to plain logging
84-
self.logger.info(f"[{self.name}] Progress {progress:.0%}")
62+
self.logger = logging.getLogger(f"ami.post_processing.{self.key}")
63+
self.log_config()
8564

8665
@abc.abstractmethod
8766
def run(self) -> None:
88-
"""
89-
Run the task logic.
90-
Must be implemented by subclasses.
91-
"""
92-
raise NotImplementedError("BasePostProcessingTask subclasses must implement run()")
67+
"""Run the task logic. Must be implemented by subclasses."""
68+
raise NotImplementedError("Subclasses must implement run()")
69+
70+
def log_config(self):
71+
"""Helper to log the task configuration at start."""
72+
self.logger.info(f"Running task {self.name} ({self.key}) with config: {self.config}")

ami/ml/post_processing/small_size_filter.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run(self) -> None:
3131

3232
try:
3333
collection = SourceImageCollection.objects.get(pk=collection_id)
34-
self.logger.info(f"Loaded SourceImageCollection {collection_id} (Project={collection.project})")
34+
self.logger.info(f"Loaded SourceImageCollection {collection_id} " f"(Project={collection.project})")
3535
except SourceImageCollection.DoesNotExist:
3636
msg = f"SourceImageCollection {collection_id} not found"
3737
self.logger.error(msg)
@@ -85,11 +85,6 @@ def run(self) -> None:
8585
comment=f"Auto-set by {self.name} post-processing task",
8686
)
8787
modified += 1
88-
self.logger.info(f"Detection {det.pk}: marked as 'Not identifiable'")
89-
90-
# Update progress every 10 detections
91-
if i % 10 == 0 or i == total:
92-
progress = i / total if total > 0 else 1.0
93-
self.update_progress(progress)
88+
self.logger.debug(f"Detection {det.pk}: marked as 'Not identifiable'")
9489

9590
self.logger.info(f"=== Completed {self.name}: {modified}/{total} detections modified ===")

0 commit comments

Comments
 (0)