
Commit 9023c78

Implement AiiDA provenance tracking for Airflow via XCom backend and listeners
Add comprehensive AiiDA provenance graph creation for Airflow DAG executions:

XCom Backend (src/airflow_provider_aiida/xcom/backend.py):
- Custom XCom backend that creates AiiDA nodes during serialization/deserialization
- WorkChainNode represents the entire DAG run (dag_id + run_id)
- CalcJobNode represents individual tasks (task_id + run_id + map_index)
- Data nodes store XCom values as AiiDA-typed data
- Establishes provenance links:
  * CALL_CALC: WorkChain → CalcJob (workflow calls calculation)
  * CREATE: CalcJob → Data (task produces output)
  * INPUT_CALC: Data → CalcJob (data flows to consuming task)
- Smart link labeling:
  * For PythonOperator: extracts parameter names via inspect.signature()
  * For other operators: deterministic hash based on producer task info
- Handles duplicate-link prevention and stored-node constraints
- Monkey-patches link validation for stored nodes when necessary

Provenance Listener (src/airflow_provider_aiida/plugins/provenance_listener.py):
- Airflow listener plugin that updates AiiDA node states in real time
- Hooks into the DAG run lifecycle (success, failure, running)
- Hooks into the task instance lifecycle (success, failure, running)
- Maps Airflow states to AiiDA ProcessStates:
  * QUEUED/SCHEDULED → CREATED
  * RUNNING → RUNNING
  * SUCCESS → FINISHED
  * FAILED → EXCEPTED
  * SKIPPED → KILLED
- Creates nodes proactively if they don't exist (handles tasks starting before XCom)
- Registered as ProvenanceListenerPlugin for automatic discovery

Common utilities for handling AiiDA nodes (src/airflow_provider_aiida/common/utils.py):
- _get_or_create_workchain_node: query by unique_id or create a new WorkChainNode
- _get_or_create_calcjob_node: query by unique_id or create a new CalcJobNode
- _sanitize_link_label: ensure AiiDA-compatible link labels (alphanumeric + underscore)
- All new nodes are initialized with ProcessState.CREATED

Caveats:
- When deserializing, no information about the input key is available, so an educated guess has to be made; for the moment this fails when maps are used.
- on_dag_run_running is not called in the test-run environment, so the WorkChainNode is created in the on_task_run_running function instead.
- Because Airflow gives no guarantee about the ordering of listener callbacks (executed by the task instance) and the XCom backend (executed by the scheduler), the logic has to be made redundant in both the XCom backend and the listeners.

Result: a complete AiiDA provenance graph mirroring the Airflow DAG structure, with real-time state synchronization and proper data-lineage tracking.
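The state mapping listed above can be sketched as a plain lookup table. This is a minimal illustration only: the dict and function names are assumptions, and lowercase strings stand in for the real enum values (Airflow's `TaskInstanceState` and plumpy's `ProcessState`), not the provider's actual API.

```python
# Minimal sketch of the Airflow -> AiiDA state mapping described above.
# Names and string values here are illustrative, not the actual API.
AIRFLOW_TO_AIIDA_STATE = {
    'queued': 'created',
    'scheduled': 'created',
    'running': 'running',
    'success': 'finished',
    'failed': 'excepted',
    'skipped': 'killed',
}


def map_airflow_state(airflow_state: str) -> str:
    """Return the AiiDA process state corresponding to an Airflow state."""
    return AIRFLOW_TO_AIIDA_STATE[airflow_state.lower()]


print(map_airflow_state('SUCCESS'))  # → finished
print(map_airflow_state('SKIPPED'))  # → killed
```

Both QUEUED and SCHEDULED collapse onto CREATED, since AiiDA has no separate pre-queue state.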
Parent commit: 2100993

File tree

5 files changed: +1000 −4 lines

pyproject.toml

Lines changed: 3 additions & 0 deletions

@@ -43,6 +43,9 @@ provider_info = "airflow_provider_aiida.__init__:get_provider_info"
 [project.entry-points."aiida.dags"]
 aiida-standard = "airflow_provider_aiida.example_dags"
 
+[project.entry-points."airflow.plugins"]
+aiida_dag_run_listener = "airflow_provider_aiida.plugins.provenance_listener:ProvenanceListenerPlugin"
+
 [tool.hatch.version]
 path = "src/airflow_provider_aiida/__init__.py"
Lines changed: 1 addition & 0 deletions (new file)

"""Common utilities for the AiiDA Airflow provider."""
Lines changed: 162 additions & 0 deletions (new file; the commit message identifies it as src/airflow_provider_aiida/common/utils.py)

"""Common ORM utilities for AiiDA nodes in the Airflow provider."""

import re
import logging
import os

logger = logging.getLogger(__name__)

# Add file handler for detailed ORM logging
log_file = os.path.expanduser('~/airflow_provider_aiida_orm.log')
file_handler = logging.FileHandler(log_file, mode='a')
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# TODO: for the prototyping phase we keep the level at DEBUG
logger.setLevel(logging.DEBUG)

logger.info(f"ORM logging initialized, writing to: {log_file}")


def _sanitize_link_label(label: str) -> str:
    """Sanitize a string to be a valid AiiDA link label.

    AiiDA link labels must contain only alphanumeric characters and underscores.

    :param label: The label to sanitize
    :return: Sanitized label with only valid characters
    """
    # Replace any non-alphanumeric, non-underscore characters with underscores
    sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', label)
    # Ensure it doesn't start with a number (prepend 'x' if it does)
    if sanitized and sanitized[0].isdigit():
        sanitized = 'x' + sanitized
    return sanitized or 'result'

def _get_workchain_node(dag_id: str, run_id: str):
    """Get an existing WorkChainNode for an Airflow DAG run.

    Returns None if no node exists.

    :param dag_id: Airflow DAG ID
    :param run_id: Airflow run ID
    :return: WorkChainNode instance or None
    """
    from aiida.orm import WorkChainNode, QueryBuilder

    # Create a unique identifier for this DAG run
    unique_id = f"{dag_id}__{run_id}"

    # Check if node already exists using the unique_id stored in extras
    qb = QueryBuilder()
    qb.append(WorkChainNode, filters={'extras.airflow_unique_id': unique_id})
    results = qb.all()

    if results:
        existing_node = results[0][0]
        logger.debug(f"Found existing WorkChainNode: dag_id={dag_id}, run_id={run_id}, pk={existing_node.pk}")
        return existing_node

    logger.debug(f"No WorkChainNode found: dag_id={dag_id}, run_id={run_id}")
    return None

def _get_or_create_workchain_node(dag_id: str, run_id: str):
    """Get or create a WorkChainNode representing an Airflow DAG run.

    Sets initial state to CREATED when creating a new node.

    :param dag_id: Airflow DAG ID
    :param run_id: Airflow run ID
    :return: Tuple of (WorkChainNode instance, created: bool);
        created is True if a new node was created, False if an existing node was found
    """
    from aiida import load_profile
    from aiida.orm import WorkChainNode, QueryBuilder
    from plumpy.process_states import ProcessState

    load_profile()

    # Create a unique identifier for this DAG run
    unique_id = f"{dag_id}__{run_id}"

    # Check if node already exists using the unique_id stored in extras
    qb = QueryBuilder()
    qb.append(WorkChainNode, filters={'extras.airflow_unique_id': unique_id})
    results = qb.all()

    if results:
        existing_node = results[0][0]
        logger.debug(f"Found existing WorkChainNode: dag_id={dag_id}, run_id={run_id}, pk={existing_node.pk}")
        return existing_node, False

    # Create a new WorkChainNode for the DAG run
    logger.debug(f"Creating NEW WorkChainNode: dag_id={dag_id}, run_id={run_id}, unique_id={unique_id}")
    workchain_node = WorkChainNode()
    workchain_node.set_process_label(f"{dag_id}[{run_id}]")
    workchain_node.label = dag_id
    workchain_node.description = f"Airflow DAG run: {dag_id}, run_id: {run_id}"
    workchain_node.base.extras.set('airflow_unique_id', unique_id)
    workchain_node.base.extras.set('airflow_dag_id', dag_id)
    workchain_node.base.extras.set('airflow_run_id', run_id)
    workchain_node.base.extras.set('airflow_xcom_backend', True)

    # Set initial process state to CREATED
    workchain_node.set_process_state(ProcessState.CREATED)

    logger.info(f"Created NEW WorkChainNode: dag_id={dag_id}, run_id={run_id}, node={workchain_node}")

    return workchain_node, True

def _get_or_create_calcjob_node(task_id: str, dag_id: str, run_id: str, map_index: int = -1):
    """Get or create a CalcJobNode representing an Airflow task.

    Sets initial state to CREATED when creating a new node.

    :param task_id: Airflow task ID
    :param dag_id: Airflow DAG ID
    :param run_id: Airflow run ID
    :param map_index: Airflow map index for mapped tasks
    :return: CalcJobNode instance
    """
    from aiida import load_profile
    from aiida.orm import CalcJobNode, QueryBuilder
    from plumpy.process_states import ProcessState

    load_profile()

    # Create a unique identifier for this task execution (used in extras for lookup)
    unique_id = f"{dag_id}__{task_id}__{run_id}"
    if map_index >= 0:
        unique_id += f"__{map_index}"

    # Check if node already exists using the unique_id stored in extras
    qb = QueryBuilder()
    qb.append(CalcJobNode, filters={'extras.airflow_unique_id': unique_id})
    results = qb.all()

    if results:
        existing_node = results[0][0]
        logger.debug(f"Found existing CalcJobNode: dag_id={dag_id}, task_id={task_id}, run_id={run_id}, map_index={map_index}, pk={existing_node.pk}")
        return existing_node

    # Create a new CalcJobNode with task_id as label (more readable)
    logger.warning(f"Creating NEW CalcJobNode: dag_id={dag_id}, task_id={task_id}, run_id={run_id}, map_index={map_index}, unique_id={unique_id}")
    calc_node = CalcJobNode()
    # Use task_id as label for readability, with map_index if applicable
    calc_node.set_process_label(f"{task_id}[{map_index}]" if map_index >= 0 else task_id)
    calc_node.label = task_id
    calc_node.description = f"Airflow task: {task_id} from DAG: {dag_id}, run: {run_id}"
    calc_node.base.extras.set('airflow_unique_id', unique_id)
    calc_node.base.extras.set('airflow_dag_id', dag_id)
    calc_node.base.extras.set('airflow_task_id', task_id)
    calc_node.base.extras.set('airflow_run_id', run_id)
    calc_node.base.extras.set('airflow_map_index', map_index)
    calc_node.base.extras.set('airflow_xcom_backend', True)

    # Set initial process state to CREATED
    calc_node.set_process_state(ProcessState.CREATED)

    logger.info(f"Created NEW CalcJobNode: dag_id={dag_id}, task_id={task_id}, run_id={run_id}, map_index={map_index}, node={calc_node}")

    return calc_node
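The unique_id scheme shared by these helpers can be exercised in isolation. The sketch below factors the string construction out into a standalone function; the name `make_task_unique_id` is a hypothetical refactor, not part of this commit, but the logic mirrors the inline construction in `_get_or_create_calcjob_node`.

```python
def make_task_unique_id(dag_id: str, task_id: str, run_id: str, map_index: int = -1) -> str:
    """Build the extras lookup key used for CalcJobNodes (hypothetical
    standalone helper mirroring the inline logic above)."""
    unique_id = f"{dag_id}__{task_id}__{run_id}"
    if map_index >= 0:  # mapped task instances get their index appended
        unique_id += f"__{map_index}"
    return unique_id


print(make_task_unique_id('my_dag', 'fetch', 'run_1'))     # → my_dag__fetch__run_1
print(make_task_unique_id('my_dag', 'fetch', 'run_1', 2))  # → my_dag__fetch__run_1__2
```

Note that unmapped tasks (map_index == -1) get no suffix, so the same task produces different keys when used inside and outside a dynamic task mapping.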
