|
6 | 6 | from django.db.utils import IntegrityError |
7 | 7 | from django.core.management import call_command |
8 | 8 | from domains.etl.models import Task, TaskRun |
| 9 | +from domains.sta.models import Datastream |
9 | 10 | from hydroserverpy.etl.hydroserver import build_hydroserver_pipeline |
10 | 11 | from hydroserverpy.etl.exceptions import ETLError |
11 | 12 | from .internal import HydroServerInternalExtractor, HydroServerInternalTransformer, HydroServerInternalLoader |
12 | 13 |
|
13 | 14 |
|
def _raise_aggregation_error(message: str):
    """Abort aggregation-task validation by raising an ETLError carrying *message*."""
    raise ETLError(message)
| 17 | + |
| 18 | + |
def _validate_aggregation_task_runtime(task: Task):
    """Validate an aggregation task's mappings before the pipeline runs.

    Checks, in order, that the task has at least one mapping; that every
    mapping's sourceIdentifier parses as a UUID; that each mapping has exactly
    one path whose data_transformations is a single aggregation-type dict; that
    every path's targetIdentifier parses as a UUID; and finally that all
    referenced datastream ids exist within the task's workspace. Any violation
    raises ETLError via _raise_aggregation_error.
    """

    def _coerce_uuid(raw, message):
        # Identifiers arrive as arbitrary values; only valid UUID strings pass.
        try:
            return UUID(str(raw))
        except (TypeError, ValueError):
            _raise_aggregation_error(message)

    task_mappings = list(task.mappings.all())
    if not task_mappings:
        _raise_aggregation_error("Aggregation tasks must include at least one mapping.")

    referenced_ids = set()
    for task_mapping in task_mappings:
        referenced_ids.add(
            _coerce_uuid(
                task_mapping.source_identifier,
                "Aggregation mapping sourceIdentifier must be a valid datastream UUID.",
            )
        )

        mapping_paths = list(task_mapping.paths.all())
        if len(mapping_paths) != 1:
            _raise_aggregation_error(
                "Aggregation mappings must include exactly one target path per source."
            )

        only_path = mapping_paths[0]
        declared = only_path.data_transformations or []
        # Exactly one transformation, and it must be an aggregation spec.
        is_single_aggregation = (
            isinstance(declared, list)
            and len(declared) == 1
            and isinstance(declared[0], dict)
            and declared[0].get("type") == "aggregation"
        )
        if not is_single_aggregation:
            _raise_aggregation_error(
                "Aggregation mappings must include exactly one aggregation transformation per path."
            )

        referenced_ids.add(
            _coerce_uuid(
                only_path.target_identifier,
                "Aggregation mapping targetIdentifier must be a valid datastream UUID.",
            )
        )

    # One query resolves which referenced datastreams the workspace actually owns.
    known_ids = set(
        Datastream.objects.filter(
            thing__workspace_id=task.workspace_id,
            id__in=referenced_ids,
        ).values_list("id", flat=True)
    )
    if referenced_ids - known_ids:
        _raise_aggregation_error(
            "Aggregation source and target datastreams must exist in the task workspace."
        )
| 68 | + |
| 69 | + |
14 | 70 | @shared_task(bind=True, expires=10) |
15 | 71 | def run_etl_task(self, task_id: str): |
16 | 72 | """ |
@@ -39,6 +95,7 @@ def run_etl_task(self, task_id: str): |
39 | 95 |
|
40 | 96 | try: |
41 | 97 | if task.task_type == "Aggregation": |
| 98 | + _validate_aggregation_task_runtime(task) |
42 | 99 | etl_classes = { |
43 | 100 | "extractor_cls": HydroServerInternalExtractor, |
44 | 101 | "transformer_cls": HydroServerInternalTransformer, |
@@ -127,11 +184,15 @@ def mark_etl_task_started(sender, task_id, kwargs, **extra): |
127 | 184 | return |
128 | 185 |
|
129 | 186 | try: |
130 | | - TaskRun.objects.create( |
| 187 | + TaskRun.objects.update_or_create( |
131 | 188 | id=task_id, |
132 | | - task_id=kwargs["task_id"], |
133 | | - status="RUNNING", |
134 | | - started_at=timezone.now(), |
| 189 | + defaults={ |
| 190 | + "task_id": kwargs["task_id"], |
| 191 | + "status": "RUNNING", |
| 192 | + "started_at": timezone.now(), |
| 193 | + "finished_at": None, |
| 194 | + "result": None, |
| 195 | + }, |
135 | 196 | ) |
136 | 197 | except IntegrityError: |
137 | 198 | return |
@@ -198,13 +259,16 @@ def mark_etl_task_failure(sender, task_id, einfo, exception, **extra): |
198 | 259 | except TaskRun.DoesNotExist: |
199 | 260 | return |
200 | 261 |
|
201 | | - task_run.status = "FAILURE" |
202 | | - task_run.finished_at = timezone.now() |
203 | | - task_run.result = { |
| 262 | + result = { |
| 263 | + "message": str(exception), |
204 | 264 | "error": str(exception), |
205 | 265 | "traceback": einfo.traceback, |
206 | | - **(getattr(exception, "results", None) or {}), |
207 | 266 | } |
| 267 | + result.update(getattr(exception, "results", None) or {}) |
| 268 | + |
| 269 | + task_run.status = "FAILURE" |
| 270 | + task_run.finished_at = timezone.now() |
| 271 | + task_run.result = result |
208 | 272 |
|
209 | 273 | task_run.save(update_fields=["status", "finished_at", "result"]) |
210 | 274 |
|
|
0 commit comments