18 | 18 | from __future__ import annotations |
19 | 19 |
20 | 20 | import datetime |
21 | | -from collections.abc import Iterable, Sequence |
22 | 21 | from typing import TYPE_CHECKING, Any, Callable |
23 | 22 |
24 | | -from sqlalchemy import select |
25 | | - |
26 | 23 | from airflow.configuration import conf |
27 | 24 | from airflow.sdk.definitions._internal.abstractoperator import ( |
28 | 25 | AbstractOperator as TaskSDKAbstractOperator, |
29 | 26 | NotMapped as NotMapped, # Re-export this for compat |
30 | 27 | ) |
31 | 28 | from airflow.sdk.definitions.context import Context |
32 | | -from airflow.utils.db import exists_query |
33 | 29 | from airflow.utils.log.logging_mixin import LoggingMixin |
34 | | -from airflow.utils.sqlalchemy import with_row_locks |
35 | | -from airflow.utils.state import State, TaskInstanceState |
36 | 30 | from airflow.utils.trigger_rule import TriggerRule |
37 | 31 | from airflow.utils.weight_rule import db_safe_priority |
38 | 32 |
39 | 33 | if TYPE_CHECKING: |
40 | 34 | from sqlalchemy.orm import Session |
41 | 35 |
42 | | - from airflow.models.dag import DAG as SchedulerDAG |
43 | | - from airflow.models.taskinstance import TaskInstance |
44 | 36 | from airflow.sdk.definitions.baseoperator import BaseOperator |
45 | 37 | from airflow.task.priority_strategy import PriorityWeightStrategy |
46 | 38 | from airflow.triggers.base import StartTriggerArgs |
@@ -152,138 +144,3 @@ def priority_weight_total(self) -> int: |
152 | 144 | for task_id in self.get_flat_relative_ids(upstream=upstream) |
153 | 145 | ) |
154 | 146 | ) |
155 | | - |
156 | | - def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: |
157 | | - """ |
158 | | - Create the mapped task instances for a mapped task. |
159 | | -
160 | | - :raise NotMapped: If this task does not need expansion. |
161 | | - :return: The newly created mapped task instances (if any) in ascending |
162 | | - order by map index, and the maximum map index value. |
163 | | - """ |
164 | | - from sqlalchemy import func, or_ |
165 | | - |
166 | | - from airflow.models.taskinstance import TaskInstance |
167 | | - from airflow.sdk.definitions.baseoperator import BaseOperator |
168 | | - from airflow.sdk.definitions.mappedoperator import MappedOperator |
169 | | - from airflow.settings import task_instance_mutation_hook |
170 | | - |
171 | | - if not isinstance(self, (BaseOperator, MappedOperator)): |
172 | | - raise RuntimeError( |
173 | | - f"cannot expand unrecognized operator type {type(self).__module__}.{type(self).__name__}" |
174 | | - ) |
175 | | - |
176 | | - from airflow.models.baseoperator import BaseOperator as DBBaseOperator |
177 | | - from airflow.models.expandinput import NotFullyPopulated |
178 | | - |
179 | | - try: |
180 | | - total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session) |
181 | | - except NotFullyPopulated as e: |
182 | | - # It's possible that the upstream tasks are not yet done, but we |
183 | | - # don't have upstream of upstreams in partial DAGs (possible in the |
184 | | - # mini-scheduler), so we ignore this exception. |
185 | | - if not self.dag or not self.dag.partial: |
186 | | - self.log.error( |
187 | | - "Cannot expand %r for run %s; missing upstream values: %s", |
188 | | - self, |
189 | | - run_id, |
190 | | - sorted(e.missing), |
191 | | - ) |
192 | | - total_length = None |
193 | | - |
194 | | - state: TaskInstanceState | None = None |
195 | | - unmapped_ti: TaskInstance | None = session.scalars( |
196 | | - select(TaskInstance).where( |
197 | | - TaskInstance.dag_id == self.dag_id, |
198 | | - TaskInstance.task_id == self.task_id, |
199 | | - TaskInstance.run_id == run_id, |
200 | | - TaskInstance.map_index == -1, |
201 | | - or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), |
202 | | - ) |
203 | | - ).one_or_none() |
204 | | - |
205 | | - all_expanded_tis: list[TaskInstance] = [] |
206 | | - |
207 | | - if unmapped_ti: |
208 | | - if TYPE_CHECKING: |
209 | | - assert self.dag is None or isinstance(self.dag, SchedulerDAG) |
210 | | - |
211 | | - # The unmapped task instance still exists and is unfinished, i.e. we |
212 | | - # haven't tried to run it before. |
213 | | - if total_length is None: |
214 | | - # If the DAG is partial, it's likely that the upstream tasks |
215 | | - # are not done yet, so the task can't fail yet. |
216 | | - if not self.dag or not self.dag.partial: |
217 | | - unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED |
218 | | - elif total_length < 1: |
219 | | - # If the upstream maps this to a zero-length value, simply mark |
220 | | - # the unmapped task instance as SKIPPED (if needed). |
221 | | - self.log.info( |
222 | | - "Marking %s as SKIPPED since the map has %d values to expand", |
223 | | - unmapped_ti, |
224 | | - total_length, |
225 | | - ) |
226 | | - unmapped_ti.state = TaskInstanceState.SKIPPED |
227 | | - else: |
228 | | - zero_index_ti_exists = exists_query( |
229 | | - TaskInstance.dag_id == self.dag_id, |
230 | | - TaskInstance.task_id == self.task_id, |
231 | | - TaskInstance.run_id == run_id, |
232 | | - TaskInstance.map_index == 0, |
233 | | - session=session, |
234 | | - ) |
235 | | - if not zero_index_ti_exists: |
236 | | - # Otherwise convert this into the first mapped index, and create |
237 | | - # TaskInstance for other indexes. |
238 | | - unmapped_ti.map_index = 0 |
239 | | - self.log.debug("Updated in place to become %s", unmapped_ti) |
240 | | - all_expanded_tis.append(unmapped_ti) |
241 | | - # execute hook for task instance map index 0 |
242 | | - task_instance_mutation_hook(unmapped_ti) |
243 | | - session.flush() |
244 | | - else: |
245 | | - self.log.debug("Deleting the original task instance: %s", unmapped_ti) |
246 | | - session.delete(unmapped_ti) |
247 | | - state = unmapped_ti.state |
248 | | - |
249 | | - if total_length is None or total_length < 1: |
250 | | - # Nothing to fixup. |
251 | | - indexes_to_map: Iterable[int] = () |
252 | | - else: |
253 | | - # Only create "missing" ones. |
254 | | - current_max_mapping = session.scalar( |
255 | | - select(func.max(TaskInstance.map_index)).where( |
256 | | - TaskInstance.dag_id == self.dag_id, |
257 | | - TaskInstance.task_id == self.task_id, |
258 | | - TaskInstance.run_id == run_id, |
259 | | - ) |
260 | | - ) |
261 | | - indexes_to_map = range(current_max_mapping + 1, total_length) |
262 | | - |
263 | | - for index in indexes_to_map: |
264 | | - # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. |
265 | | - ti = TaskInstance(self, run_id=run_id, map_index=index, state=state) |
266 | | - self.log.debug("Expanding TIs upserted %s", ti) |
267 | | - task_instance_mutation_hook(ti) |
268 | | - ti = session.merge(ti) |
269 | | - ti.refresh_from_task(self) # session.merge() loses task information. |
270 | | - all_expanded_tis.append(ti) |
271 | | - |
272 | | - # Coerce the None case to 0 -- the two cases are treated almost identically, |
273 | | - # except that the unmapped ti (if it exists) is marked with a different state. |
274 | | - total_expanded_ti_count = total_length or 0 |
275 | | - |
276 | | - # Any (old) task instances with inapplicable indexes (>= the total |
277 | | - # number we need) are set to "REMOVED". |
278 | | - query = select(TaskInstance).where( |
279 | | - TaskInstance.dag_id == self.dag_id, |
280 | | - TaskInstance.task_id == self.task_id, |
281 | | - TaskInstance.run_id == run_id, |
282 | | - TaskInstance.map_index >= total_expanded_ti_count, |
283 | | - ) |
284 | | - query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True) |
285 | | - to_update = session.scalars(query) |
286 | | - for ti in to_update: |
287 | | - ti.state = TaskInstanceState.REMOVED |
288 | | - session.flush() |
289 | | - return all_expanded_tis, total_expanded_ti_count - 1 |
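
The core of the removed `expand_mapped_task` is an index reconciliation: given the number of values the task maps over (`total_length`, possibly unknown) and the map indexes that already have task instances, it creates only the missing tail of indexes and marks any index at or past the new total as REMOVED. Below is a minimal, self-contained sketch of that decision in plain Python; `reconcile_map_indexes` and its inputs are illustrative names, not Airflow API, and the removed code reads the current maximum index from the database rather than from a set.

```python
from __future__ import annotations


def reconcile_map_indexes(
    total_length: int | None, existing_indexes: set[int]
) -> tuple[range, set[int]]:
    """Return (indexes to create, stale indexes to mark REMOVED)."""
    total = total_length or 0  # None is coerced to 0, as in the removed code
    if total < 1:
        to_create = range(0)  # nothing to create
    else:
        # Only create the "missing" tail, mirroring
        # range(current_max_mapping + 1, total_length) above.
        current_max = max(existing_indexes, default=-1)
        to_create = range(current_max + 1, total)
    # Old task instances whose index no longer fits get marked REMOVED.
    stale = {i for i in existing_indexes if i >= total}
    return to_create, stale


# A rerun where the upstream now returns 3 values but 5 TIs already exist:
to_create, stale = reconcile_map_indexes(3, {0, 1, 2, 3, 4})
assert list(to_create) == [] and stale == {3, 4}
```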
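
Before any expansion happens, the removed code decides what to do with the placeholder task instance (the row with `map_index == -1`): fail it when the length is unknowable, skip it when the map is empty, and otherwise either reuse it as index 0 or delete it if an index-0 row already exists. A hedged sketch of that decision table, with illustrative names only:

```python
from enum import Enum


class Disposition(Enum):
    LEAVE_ALONE = "leave alone"    # partial DAG: upstreams may still finish
    UPSTREAM_FAILED = "upstream failed"
    SKIPPED = "skipped"            # zero-length map: nothing to run
    REUSE_AS_INDEX_0 = "reuse"     # convert the row into the first mapped TI
    DELETE = "delete"              # an index-0 row already exists


def placeholder_disposition(
    total_length: int | None, dag_is_partial: bool, zero_index_exists: bool
) -> Disposition:
    if total_length is None:
        # Length unknown: only fail if all upstreams are actually visible.
        if dag_is_partial:
            return Disposition.LEAVE_ALONE
        return Disposition.UPSTREAM_FAILED
    if total_length < 1:
        return Disposition.SKIPPED
    if zero_index_exists:
        return Disposition.DELETE
    return Disposition.REUSE_AS_INDEX_0


assert placeholder_disposition(0, False, False) is Disposition.SKIPPED
assert placeholder_disposition(4, False, False) is Disposition.REUSE_AS_INDEX_0
```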
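
When marking stale indexes REMOVED, the removed code wraps its SELECT in `with_row_locks(..., skip_locked=True)` so that concurrent schedulers skip rows another transaction already holds instead of blocking on them. Here is a minimal SQLAlchemy sketch of the underlying SELECT ... FOR UPDATE SKIP LOCKED pattern; the `TI` model is a toy, and SQLite is used only to keep the snippet runnable (it silently ignores FOR UPDATE, much as Airflow's helper omits the clause on databases that cannot honor it).

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class TI(Base):
    __tablename__ = "task_instance"
    id = Column(Integer, primary_key=True)
    map_index = Column(Integer, default=-1)
    state = Column(String, nullable=True)


# SQLite keeps the sketch runnable but ignores FOR UPDATE; point the engine
# at PostgreSQL or MySQL 8+ to see SKIP LOCKED actually take effect.
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([TI(map_index=i) for i in range(5)])
    new_total = 3  # the map shrank; indexes 3 and 4 are now stale
    query = (
        select(TI)
        .where(TI.map_index >= new_total)
        .with_for_update(skip_locked=True)  # the clause with_row_locks adds
    )
    for ti in session.scalars(query):
        ti.state = "removed"
    session.commit()
```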
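
Each task instance the removed code creates or converts is also passed through `task_instance_mutation_hook` before the session is flushed. That hook is a cluster policy users may define in `airflow_local_settings.py`; the body below is purely illustrative of the kind of in-place mutation such a hook performs, and the queue name is a made-up example.

```python
# airflow_local_settings.py
def task_instance_mutation_hook(task_instance):
    """Mutate a TaskInstance in place before Airflow persists or queues it."""
    if task_instance.map_index == 0:
        # Hypothetical routing rule: send the first expanded TI to a
        # dedicated queue.
        task_instance.queue = "priority"
```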