|
2 | 2 |
|
3 | 3 | import abc |
4 | 4 | import logging |
5 | | -from typing import Any |
| 5 | +from typing import Any, Optional |
6 | 6 |
|
7 | 7 | from ami.jobs.models import Job |
8 | 8 | from ami.ml.models import Algorithm |
@@ -39,54 +39,34 @@ class BasePostProcessingTask(abc.ABC): |
39 | 39 | Abstract base class for all post-processing tasks. |
40 | 40 | """ |
41 | 41 |
|
42 | | - # Each task must override these |
43 | 42 | key: str = "" |
44 | 43 | name: str = "" |
45 | 44 |
|
46 | 45 | def __init__( |
47 | 46 | self, |
48 | | - job: Job | None = None, |
49 | | - logger: logging.Logger | None = None, |
| 47 | + job: Optional["Job"] = None, |
| 48 | + task_logger: logging.Logger | None = None, |
50 | 49 | **config: Any, |
51 | 50 | ): |
52 | | - self.job = job |
53 | | - self.config = config |
54 | | - # Choose the right logger |
55 | | - if logger is not None: |
56 | | - self.logger = logger |
57 | | - elif job is not None: |
58 | | - self.logger = job.logger |
59 | | - else: |
60 | | - self.logger = logging.getLogger(f"ami.post_processing.{self.key}") |
61 | | - |
62 | | - algorithm, _ = Algorithm.objects.get_or_create( |
63 | | - name=self.__class__.__name__, |
64 | | - defaults={ |
65 | | - "description": f"Post-processing task: {self.key}", |
66 | | - "task_type": AlgorithmTaskType.POST_PROCESSING.value, |
67 | | - }, |
68 | | - ) |
69 | | - self.algorithm: Algorithm = algorithm |
70 | | - |
71 | | - self.logger.info(f"Initialized {self.__class__.__name__} with config={self.config}, job={job}") |
72 | | - |
73 | | - def update_progress(self, progress: float): |
74 | 51 | """ |
75 | | - Update progress if job is present, otherwise just log. |
| 52 | + Initialize task with optional job and logger context. |
76 | 53 | """ |
| 54 | + self.job = job |
| 55 | + self.config: dict[str, Any] = config |
77 | 56 |
|
78 | | - if self.job: |
79 | | - self.job.progress.update_stage(self.job.job_type_key, progress=progress) |
80 | | - self.job.save(update_fields=["progress"]) |
81 | | - |
| 57 | + if job: |
| 58 | + self.logger = job.logger |
| 59 | + elif task_logger: |
| 60 | + self.logger = task_logger |
82 | 61 | else: |
83 | | - # No job object — fallback to plain logging |
84 | | - self.logger.info(f"[{self.name}] Progress {progress:.0%}") |
| 62 | + self.logger = logging.getLogger(f"ami.post_processing.{self.key}") |
| 63 | + self.log_config() |
85 | 64 |
|
86 | 65 | @abc.abstractmethod |
87 | 66 | def run(self) -> None: |
88 | | - """ |
89 | | - Run the task logic. |
90 | | - Must be implemented by subclasses. |
91 | | - """ |
92 | | - raise NotImplementedError("BasePostProcessingTask subclasses must implement run()") |
| 67 | + """Run the task logic. Must be implemented by subclasses.""" |
| 68 | + raise NotImplementedError("Subclasses must implement run()") |
| 69 | + |
| 70 | + def log_config(self): |
| 71 | + """Helper to log the task configuration at start.""" |
| 72 | + self.logger.info(f"Running task {self.name} ({self.key}) with config: {self.config}") |
0 commit comments