1616# under the License.
1717from __future__ import annotations
1818
19- from typing import TYPE_CHECKING , cast
19+ from typing import TYPE_CHECKING , Any
2020
21- from airflow .sdk .definitions ._internal .abstractoperator import AbstractOperator
22- from airflow .serialization .definitions .baseoperator import SerializedBaseOperator
23- from airflow .serialization .definitions .dag import SerializedDAG
24- from airflow .serialization .definitions .mappedoperator import SerializedMappedOperator
21+ from airflow .serialization .definitions .taskgroup import SerializedTaskGroup
22+
23+ # Also support SDK types if possible.
24+ try :
25+ from airflow .sdk import TaskGroup
26+ except ImportError :
27+ TaskGroup = SerializedTaskGroup # type: ignore[misc]
2528
2629if TYPE_CHECKING :
27- from airflow .sdk import DAG
2830 from airflow .serialization .definitions .dag import SerializedDAG
2931 from airflow .serialization .definitions .mappedoperator import Operator
32+ from airflow .serialization .definitions .node import DAGNode
3033
3134
32- def dag_edges (dag : DAG | SerializedDAG ):
35+ def dag_edges (dag : SerializedDAG ):
3336 """
3437 Create the list of edges needed to construct the Graph view.
3538
@@ -62,9 +65,10 @@ def dag_edges(dag: DAG | SerializedDAG):
6265
6366 task_group_map = dag .task_group .get_task_group_dict ()
6467
65- def collect_edges (task_group ) :
68+ def collect_edges (task_group : DAGNode ) -> None :
6669 """Update edges_to_add and edges_to_skip according to TaskGroups."""
67- if isinstance (task_group , (AbstractOperator , SerializedBaseOperator , SerializedMappedOperator )):
70+ child : DAGNode
71+ if not isinstance (task_group , (TaskGroup , SerializedTaskGroup )):
6872 return
6973
7074 for target_id in task_group .downstream_group_ids :
@@ -111,9 +115,7 @@ def collect_edges(task_group):
111115 edges = set ()
112116 setup_teardown_edges = set ()
113117
114- # TODO (GH-52141): 'roots' in scheduler needs to return scheduler types
115- # instead, but currently it inherits SDK's DAG.
116- tasks_to_trace = cast ("list[Operator]" , dag .roots )
118+ tasks_to_trace = dag .roots
117119 while tasks_to_trace :
118120 tasks_to_trace_next : list [Operator ] = []
119121 for task in tasks_to_trace :
@@ -130,7 +132,7 @@ def collect_edges(task_group):
130132 # Build result dicts with the two ends of the edge, plus any extra metadata
131133 # if we have it.
132134 for source_id , target_id in sorted (edges .union (edges_to_add ) - edges_to_skip ):
133- record = {"source_id" : source_id , "target_id" : target_id }
135+ record : dict [ str , Any ] = {"source_id" : source_id , "target_id" : target_id }
134136 label = dag .get_edge_info (source_id , target_id ).get ("label" )
135137 if (source_id , target_id ) in setup_teardown_edges :
136138 record ["is_setup_teardown" ] = True
0 commit comments