Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/trigger_arithmetic_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'code': code,
'x': Int(0),
'y': Int(1),
#'metadata': {'options': {'sleep': 5}}
'metadata': {'options': {'sleep': 500}}
}

dag_run = submit(ArithmeticAddCalculation, inputs)
Expand Down
2 changes: 2 additions & 0 deletions src/airflow_provider_aiida/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ def airdi():

# Register command groups
from airflow_provider_aiida.cli.cmd_presto import airdi_presto
from airflow_provider_aiida.cli.cmd_process import airdi_process
from airflow_provider_aiida.cli.cmd_services import airdi_services
from airflow_provider_aiida.cli.cmd_status import airdi_status
from airflow_provider_aiida.cli.cmd_airflow import airdi_airflow

airdi.add_command(airdi_presto)
airdi.add_command(airdi_process)
airdi.add_command(airdi_services)
airdi.add_command(airdi_status)
airdi.add_command(airdi_airflow)
185 changes: 185 additions & 0 deletions src/airflow_provider_aiida/cli/cmd_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Commands for managing AiiDA processes in Airflow."""
import click


@click.group('process')
def airdi_process():
    """Manage AiiDA processes in Airflow.

    Provides commands for pausing (marking as failed) and playing
    (clearing/resuming) DAG runs associated with AiiDA process nodes.
    """


@airdi_process.command('stop')
@click.argument('pk', type=int)
def process_pause(pk):
    """
    Stop the Airflow DAG run for an AiiDA process.

    This marks the DAG run associated with the process node as failed,
    effectively pausing its execution in Airflow, and pauses the AiiDA
    process node itself.

    Arguments:
        PK: The primary key (PK) of the AiiDA process node

    Examples:
        airdi process stop 123
        airdi process stop 456
    """
    from airflow_provider_aiida.aiida_core import load_profile
    from airflow_provider_aiida.utils.airflow_control import mark_aiida_process_dag_run_failed, get_dag_run_id
    from aiida.orm import load_node

    try:
        # Load the AiiDA profile before touching any node.
        load_profile()

        # Load the process node; abort with a readable message on failure.
        click.echo(f"Loading process node {pk}...")
        try:
            node = load_node(pk)
        except Exception as e:
            click.secho(f"✗ Error: Could not load node with PK {pk}: {e}", fg='red', err=True)
            raise click.Abort()

        # The DAG run ID is stored on the node by the Airflow operator;
        # its absence means the node is not managed by Airflow.
        dag_run_id = get_dag_run_id(node)
        if dag_run_id is None:
            click.secho(
                f"✗ Error: Node {pk} does not have a DAG run ID attribute. "
                f"This node may not be managed by Airflow.",
                fg='red',
                err=True
            )
            raise click.Abort()

        click.echo(f"DAG run ID: {dag_run_id}")

        # The process class is needed to derive the Airflow DAG ID.
        process_type = node.process_class
        if not process_type:
            click.secho(f"✗ Error: Node {pk} does not have a process type", fg='red', err=True)
            raise click.Abort()

        click.echo(f"Process type: {process_type}")
        click.echo("\nMarking DAG run as failed...")

        # Mark the DAG run as failed and pause the AiiDA node so both
        # sides agree on the process state.
        # TODO(review): mark_aiida_process_dag_run_failed does not report
        # success/failure — check its result once the helper supports it.
        mark_aiida_process_dag_run_failed(process_type, dag_run_id)
        node.pause()
        # NOTE(review): node is already stored when loaded; store() here is
        # a no-op kept for safety — confirm and drop if redundant.
        node.store()

        click.secho(f"✓ Successfully paused process {pk}", fg='green', bold=True)
        click.echo(f"  DAG run '{dag_run_id}' has been marked as failed")

    except click.Abort:
        raise
    except Exception as e:
        # Top-level CLI boundary: show the error and a traceback, then abort.
        click.secho(f"✗ Error: {e}", fg='red', err=True)
        import traceback
        traceback.print_exc()
        raise click.Abort()


@airdi_process.command('continue')
@click.argument('pk', type=int)
@click.option(
    '--dry-run',
    is_flag=True,
    help='Show what would be cleared without actually clearing'
)
@click.option(
    '--only-failed',
    is_flag=True,
    help='Only clear failed tasks'
)
def process_play(pk, dry_run, only_failed):
    """
    Continues the Airflow DAG run for an AiiDA process.

    This clears the DAG run associated with the process node, allowing it
    to be re-run or resumed in Airflow, and unpauses the AiiDA process node.

    Arguments:
        PK: The primary key (PK) of the AiiDA process node

    Examples:
        airdi process continue 123
        airdi process continue 456 --dry-run
        airdi process continue 789 --only-failed
    """
    from airflow_provider_aiida.aiida_core import load_profile
    from airflow_provider_aiida.utils.airflow_control import clear_aiida_process_dag_run, get_dag_run_id
    from aiida.orm import load_node

    try:
        # Load the AiiDA profile before touching any node.
        load_profile()

        # Load the process node; abort with a readable message on failure.
        click.echo(f"Loading process node {pk}...")
        try:
            node = load_node(pk)
        except Exception as e:
            click.secho(f"✗ Error: Could not load node with PK {pk}: {e}", fg='red', err=True)
            raise click.Abort()

        # The DAG run ID is stored on the node by the Airflow operator;
        # its absence means the node is not managed by Airflow.
        dag_run_id = get_dag_run_id(node)
        if dag_run_id is None:
            click.secho(
                f"✗ Error: Node {pk} does not have a DAG run ID attribute. "
                f"This node may not be managed by Airflow.",
                fg='red',
                err=True
            )
            raise click.Abort()

        click.echo(f"DAG run ID: {dag_run_id}")

        # The process class is needed to derive the Airflow DAG ID.
        process_type = node.process_class

        if not process_type:
            click.secho(f"✗ Error: Node {pk} does not have a process type", fg='red', err=True)
            raise click.Abort()

        click.echo(f"Process type: {process_type}")

        if dry_run:
            click.echo("\n[DRY RUN] Clearing DAG run...")
        else:
            click.echo("\nClearing DAG run...")

        # Clear the DAG run on the Airflow side.
        # TODO(review): clear_aiida_process_dag_run does not report
        # success/failure — check its result once the helper supports it.
        result = clear_aiida_process_dag_run(
            process_type,
            dag_run_id,
            dry_run=dry_run,
            only_failed=only_failed
        )
        # Only mutate the node when actually clearing: a dry run must not
        # change any state (previously unpause() ran even with --dry-run).
        if not dry_run:
            node.unpause()

        if dry_run:
            click.secho(f"✓ [DRY RUN] Would clear process {pk}", fg='yellow', bold=True)
            click.echo(f"  DAG run '{dag_run_id}' would be cleared")
            if result:
                click.echo(f"  Tasks that would be cleared: {len(result)}")
        else:
            click.secho(f"✓ Successfully cleared process {pk}", fg='green', bold=True)
            click.echo(f"  DAG run '{dag_run_id}' has been cleared and can be re-run")

    except click.Abort:
        raise
    except Exception as e:
        # Top-level CLI boundary: show the error and a traceback, then abort.
        click.secho(f"✗ Error: {e}", fg='red', err=True)
        import traceback
        traceback.print_exc()
        raise click.Abort()
12 changes: 12 additions & 0 deletions src/airflow_provider_aiida/operators/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from airflow.models import BaseOperator
from airflow_provider_aiida.triggers.process import ProcStepUntilTerminatedTrigger
from airflow_provider_aiida.utils.airflow_control import set_dag_run_id

from airflow.utils.context import Context

Expand All @@ -24,6 +25,17 @@ def __init__(self,
self.aiida_path = aiida_path

def execute(self, context: Context):
# Add dag_run_id to the process extras and attributes
from aiida import load_profile
load_profile(self.aiida_profile)
from aiida.orm import load_node

node = load_node(self.process_pk)

# Try to get dag_run_id from context and set it on the node

set_dag_run_id(node, context['run_id'])

self.defer(
trigger=ProcStepUntilTerminatedTrigger(
process_pk=self.process_pk,
Expand Down
38 changes: 2 additions & 36 deletions src/airflow_provider_aiida/triggers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,15 @@
"""

import logging
from typing import Any, AsyncIterator, TYPE_CHECKING
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow_provider_aiida.utils.airflow_control import load_process

from airflow_provider_aiida.aiida_core.engine.runner import AirflowRunner

if TYPE_CHECKING:
from asyncio import AbstractEventLoop

logger = logging.getLogger(__name__)


def get_current_event_loop() -> 'AbstractEventLoop':
    """Return the currently running asyncio event loop.

    If no loop is running (e.g. called from synchronous code), a fresh
    event loop is created and installed as the current one.
    """
    import asyncio
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        # Not inside a running loop: build one and register it.
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop


def load_process(process_pk: int, aiida_profile: str | None, aiida_path: str | None):
    """Reconstruct a persisted AiiDA process from its checkpoint.

    Loads the AiiDA profile, reads the process checkpoint for
    ``process_pk`` from the runner's persister, and unbundles it back
    into a live process object bound to a fresh ``AirflowRunner`` and
    the current event loop, so it re-enters the same state it was
    persisted in.

    :param process_pk: primary key of the persisted AiiDA process node
    :param aiida_profile: profile name passed to ``aiida.load_profile``
        (``None`` selects the default profile)
    :param aiida_path: if given, exported as the ``AIIDA_PATH`` environment
        variable before the profile is loaded
    :return: the unbundled process instance, re-wired to a new runner/loop
    """
    import os
    # TODO find a solution that gives understandable error message
    # NOTE: this conflicts if profiles from different aiida paths are used
    # (AIIDA_PATH is process-global state, so the last caller wins)
    if aiida_path is not None:
        os.environ["AIIDA_PATH"] = aiida_path
    from aiida import load_profile
    load_profile(aiida_profile)
    from plumpy.persistence import LoadSaveContext
    loop = get_current_event_loop()
    runner = AirflowRunner(loop=loop)
    saved_state = runner.persister.load_checkpoint(process_pk)
    proc = saved_state.unbundle(LoadSaveContext())
    # Rebind private runner/loop: the persisted ones are stale after reload.
    proc._runner = runner
    # NOTE: Overwrite persisted loop since loop might have changed
    proc._loop = loop
    return proc


class ProcStepUntilTerminatedTrigger(BaseTrigger):
"""Trigger that executes the AiiDA task_upload_job function."""

Expand Down
Loading