97 changes: 61 additions & 36 deletions src/ingest-pipeline/airflow/dags/phenocycler_deepcell.py
@@ -1,33 +1,43 @@
import re
import utils

from airflow.decorators import task
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator, BranchPythonOperator
import urllib.parse
from datetime import datetime, timedelta
from pathlib import Path

from extra_utils import build_tag_containers
from hubmap_operators.common_operators import (
CreateTmpDirOperator,
LogInfoOperator,
)
from hubmap_operators.flex_multi_dag_run import FlexMultiDagRunOperator
from status_change.callbacks.failure_callback import FailureCallback
from utils import (
get_queue_resource,
get_uuid_for_error,
HMDAG,
get_tmp_dir_path,
get_preserve_scratch_resource,
)
from utils import build_dataset_name as inner_build_dataset_name
from utils import (
downstream_workflow_iter,
get_absolute_workflow,
build_dataset_name as inner_build_dataset_name,
get_parent_data_dir,
get_auth_tok,
get_cwl_cmd_from_workflows,
join_quote_command_str,
get_dataset_uuid,
)
from hubmap_operators.flex_multi_dag_run import FlexMultiDagRunOperator
from hubmap_operators.common_operators import (
LogInfoOperator,
CreateTmpDirOperator,
get_parent_data_dir,
get_preserve_scratch_resource,
get_queue_resource,
get_tmp_dir_path,
get_uuid_for_error,
join_quote_command_str,
post_to_slack_notify,
pythonop_maybe_keep,
pythonop_set_dataset_state,
)

from extra_utils import build_tag_containers
from airflow.configuration import conf as airflow_conf
from airflow.decorators import task
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator

SLACK_NOTIFY_CHANNEL = "C07P2P1D5LP"

default_args = {
"owner": "hubmap",
@@ -40,7 +50,7 @@
"retry_delay": timedelta(minutes=1),
"xcom_push": True,
"queue": get_queue_resource("phenocycler_deepcell"),
"on_failure_callback": utils.create_dataset_state_error_callback(get_uuid_for_error),
"on_failure_callback": FailureCallback(__name__, get_uuid_for_error),
}
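
Reviewer note: the factory-produced callback (`utils.create_dataset_state_error_callback`) is replaced by a `FailureCallback` instance, and its base class (see the `AirflowCallback` hunk further down) now accepts an optional `dataset_uuid_callable`. A minimal sketch of the wiring, assuming only the base-class behavior shown in this PR:

```python
# Sketch, not the real implementation: Airflow invokes the configured
# on_failure_callback with the task context when a task fails.
from status_change.callbacks.failure_callback import FailureCallback
from utils import get_uuid_for_error

on_failure = FailureCallback(__name__, get_uuid_for_error)

# Airflow then calls roughly: on_failure(context)
# __call__ -> get_data(context) resolves the dataset UUID via the injected
# get_uuid_for_error callable instead of an XCom pull (see base.py hunk below).
```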

with HMDAG(
@@ -151,7 +161,7 @@ def build_cwltool_cwl_segmentation(**kwargs):

t_maybe_keep_cwl_segmentation = BranchPythonOperator(
task_id="maybe_keep_cwl_segmentation",
python_callable=utils.pythonop_maybe_keep,
python_callable=pythonop_maybe_keep,
provide_context=True,
op_kwargs={
"next_op": "prepare_stellar_pre_convert",
@@ -211,30 +221,44 @@ def build_cwltool_cwl_stellar_pre_convert(**kwargs):
""",
)

# TODO: t_notify_user_stellar_pre_convert

t_maybe_keep_cwl_stellar_pre_convert = BranchPythonOperator(
task_id="maybe_keep_cwl_stellar_pre_convert",
python_callable=utils.pythonop_maybe_keep,
python_callable=pythonop_maybe_keep,
provide_context=True,
op_kwargs={
"next_op": "prepare_cell_count_cmd",
"next_op": "notify_user_stellar_pre_convert",
"bail_op": "set_dataset_error",
"test_op": "pipeline_exec_cwl_stellar_pre_convert",
},
)
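
(As elsewhere in these DAGs, `pythonop_maybe_keep` appears to test the outcome of `test_op` and branch to `next_op` on success or `bail_op` on failure; the change here simply routes success through the new Slack notification task before the cell-count step.)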

@task
def notify_user_stellar_pre_convert(**kwargs):
run_id = kwargs["run_id"]
conf = airflow_conf.as_dict().get("webserver", {})
run_url = f"{conf.get('base_url', '')}:{conf.get('web_server_port', '')}/dags/phenocycler_deepcell_segmentation/grid?dag_run_id={urllib.parse.quote(run_id)}"
message = f"STELLAR pre-convert step succeeded in run <{run_url}|{run_id}>."
if kwargs["dag_run"].conf.get("dryrun"):
message = "[dryrun] " + message
post_to_slack_notify(get_auth_tok(**kwargs), message, SLACK_NOTIFY_CHANNEL)

t_notify_user_stellar_pre_convert = notify_user_stellar_pre_convert()
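
The notification task builds a deep link to the run's grid view from webserver config. A standalone sketch of the same URL construction, with hypothetical config values standing in for `airflow_conf`:

```python
import urllib.parse

# Hypothetical values standing in for the [webserver] section of airflow.cfg.
conf = {"base_url": "https://airflow.example.org", "web_server_port": "8080"}
run_id = "manual__2024-01-01T00:00:00+00:00"  # placeholder run id

run_url = (
    f"{conf.get('base_url', '')}:{conf.get('web_server_port', '')}"
    f"/dags/phenocycler_deepcell_segmentation/grid"
    f"?dag_run_id={urllib.parse.quote(run_id)}"
)
# quote() percent-encodes ':' and '+' in the run id, keeping the query string valid.
print(run_url)
```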

prepare_cell_count_cmd = EmptyOperator(task_id="prepare_cell_count_cmd")

@task(task_id="cell_count_cmd")
def build_cell_count_cmd(**kwargs):
tmpdir = get_tmp_dir_path(kwargs["run_id"])
print("tmpdir: ", tmpdir)
pattern = r"\: (\d+),$"
num_cells_re = None
with open(Path(tmpdir, "session.log"), "r") as f:
for line in f:
if "num_cells" in line:
num_cells = re.search(pattern, line).group(1)
num_cells_re = re.search(pattern, line)
if not num_cells_re:
raise Exception("'num_cells' not found in session.log file")
num_cells = num_cells_re.group(1)
print("num_cells: ", num_cells)
kwargs["ti"].xcom_push(key="small_sprm", value=1 if int(num_cells) > 200000 else 0)
return 0
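
The lookup now fails loudly instead of raising an opaque `AttributeError` when `re.search` returns `None`. The pattern anchors on a trailing comma, so it expects `session.log` lines shaped like the hypothetical example here:

```python
import re

pattern = r"\: (\d+),$"
line = '"num_cells": 250000,'  # hypothetical log line; the real format comes from the pipeline

match = re.search(pattern, line)
print(match.group(1) if match else "not found")  # -> 250000
```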
@@ -243,7 +267,7 @@ def build_cell_count_cmd(**kwargs):

t_maybe_start_small_sprm = BranchPythonOperator(
task_id="maybe_start_small",
python_callable=utils.pythonop_maybe_keep,
python_callable=pythonop_maybe_keep,
provide_context=True,
op_kwargs={
"next_op": "trigger_phenocycler_small",
@@ -254,14 +278,14 @@
)

def trigger_phenocycler(**kwargs):
collection_type = kwargs.get("collection_type")
assay_type = kwargs.get("assay_type")
collection_type = kwargs.get("collection_type", "")
assay_type = kwargs.get("assay_type", "")
payload = {
"tmp_dir": get_tmp_dir_path(kwargs.get("run_id")),
"parent_submission_id": kwargs.get("dag_run").conf.get("parent_submission_id"),
"parent_lz_path": kwargs.get("dag_run").conf.get("parent_lz_path"),
"previous_version_uuid": kwargs.get("dag_run").conf.get("previous_version_uuid"),
"metadata": kwargs.get("dag_run").conf.get("metadata"),
"tmp_dir": get_tmp_dir_path(kwargs["run_id"]),
"parent_submission_id": kwargs["dag_run"].conf.get("parent_submission_id"),
"parent_lz_path": kwargs["dag_run"].conf.get("parent_lz_path"),
"previous_version_uuid": kwargs["dag_run"].conf.get("previous_version_uuid"),
"metadata": kwargs["dag_run"].conf.get("metadata"),
"crypt_auth_tok": kwargs["dag_run"].conf.get("crypt_auth_tok"),
"workflows": kwargs["ti"].xcom_pull(
task_ids="build_cwl_stellar_pre_convert", key="cwl_workflows"
@@ -270,7 +294,7 @@ def trigger_phenocycler(**kwargs):
print(
f"Collection_type: {collection_type} with assay_type {assay_type} and payload: {payload}",
)
for next_dag in utils.downstream_workflow_iter(collection_type, assay_type):
for next_dag in downstream_workflow_iter(collection_type, assay_type):
yield next_dag, payload
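
(`FlexMultiDagRunOperator` evidently consumes this generator, launching one downstream DagRun per yielded `(dag_id, payload)` pair.)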

t_trigger_phenocyler_small = FlexMultiDagRunOperator(
@@ -294,7 +318,7 @@ def trigger_phenocycler(**kwargs):

t_set_dataset_error = PythonOperator(
task_id="set_dataset_error",
python_callable=utils.pythonop_set_dataset_state,
python_callable=pythonop_set_dataset_state,
provide_context=True,
trigger_rule="all_done",
op_kwargs={
@@ -316,6 +340,7 @@ def trigger_phenocycler(**kwargs):
>> t_pipeline_exec_cwl_stellar_pre_convert
>> t_copy_stellar_pre_convert_data
>> t_maybe_keep_cwl_stellar_pre_convert
>> t_notify_user_stellar_pre_convert
>> prepare_cell_count_cmd
>> cell_count_cmd
>> t_maybe_start_small_sprm
@@ -1,4 +1,5 @@
from abc import ABC
from typing import Callable

from status_change.status_utils import get_submission_context
from utils import get_auth_tok
@@ -16,8 +17,9 @@ class AirflowCallback(ABC):
)
"""

def __init__(self, module_name: str):
def __init__(self, module_name: str, dataset_uuid_callable: Callable | None = None):
self.called_from = module_name
self.dataset_uuid_callable = dataset_uuid_callable

def __call__(self, context: dict):
"""
@@ -28,7 +30,11 @@ def __call__(self, context: dict):
raise NotImplementedError

def get_data(self, context: dict):
self.uuid = context["task_instance"].xcom_pull(key="uuid")

if self.dataset_uuid_callable:
self.uuid = self.dataset_uuid_callable(**context)
else:
self.uuid = context["task_instance"].xcom_pull(key="uuid")
context["uuid"] = self.uuid
self.auth_tok = get_auth_tok(**context)
self.dag_run = context.get("dag_run")
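
A minimal sketch of a concrete callback under the new constructor, assuming only the base-class behavior visible in this hunk (the real `FailureCallback` subclass is not shown in this diff):

```python
class PrintingCallback(AirflowCallback):
    """Hypothetical subclass for illustration only."""

    def __call__(self, context: dict):
        self.get_data(context)  # resolves self.uuid via the injected callable, if any
        print(f"[{self.called_from}] failure recorded for dataset {self.uuid}")


# Wiring, mirroring the phenocycler_deepcell change:
# "on_failure_callback": PrintingCallback(__name__, get_uuid_for_error)
```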
@@ -21,7 +21,6 @@
MessageManager,
Statuses,
get_env,
post_to_slack_notify,
slack_channels_testing,
split_error_counts,
)
@@ -104,6 +103,8 @@ def get_message_class(self, msg_type: Statuses) -> Optional[SlackMessage]:
return main_class(self.uuid, self.token)

def update(self):
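# Local import: post_to_slack_notify now lives in dags/utils.py, and importing
# it at call time presumably sidesteps a circular import with status_change.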
from utils import post_to_slack_notify

if not self.message_class:
raise EntityUpdateException("Can't update Slack without message class, exiting.")
message = self.get_message()
@@ -309,14 +309,6 @@ def get_organ(uuid: str, token: str) -> str:
return ""


def post_to_slack_notify(token: str, message: str, channel: str):
http_hook = HttpHook("POST", http_conn_id="ingest_api_connection")
payload = json.dumps({"message": message, "channel": channel})
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = http_hook.run("/notify", payload, headers)
response.raise_for_status()


def get_ancestors(uuid: str, token: str) -> dict:
endpoint = f"/ancestors/{uuid}"
headers = get_headers(token)
9 changes: 8 additions & 1 deletion src/ingest-pipeline/airflow/dags/utils.py
@@ -1964,10 +1964,17 @@ def search_api_reindex(uuid, **kwargs):
except HTTPError as e:
print(f"Redinex for {uuid} failed. ERROR: {e}")
return False

return True


def post_to_slack_notify(token: str, message: str, channel: str):
http_hook = HttpHook("POST", http_conn_id="ingest_api_connection")
payload = json.dumps({"message": message, "channel": channel})
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = http_hook.run("/notify", payload, headers)
response.raise_for_status()
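
The helper moves here unchanged from the status_change helpers diffed above: it relays the message through the ingest API's /notify endpoint rather than hitting Slack directly. A usage sketch with a placeholder token and the channel id hard-coded in phenocycler_deepcell.py:

```python
# Placeholder token; in the DAGs this comes from get_auth_tok(**kwargs).
post_to_slack_notify(
    token="<groups-token>",
    message="STELLAR pre-convert step succeeded.",
    channel="C07P2P1D5LP",  # SLACK_NOTIFY_CHANNEL
)
```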


def main():
"""
This provides some unit tests. To run it, you will need to define the
2 changes: 1 addition & 1 deletion tests/test_status_changer.py
@@ -111,7 +111,7 @@ def setUp(self):
"status_change.slack.base.get_submission_context", return_value=good_upload_context
)
self.slack_update = patch("status_change.slack_manager.SlackManager.update")
self.slack_post = patch("status_change.slack_manager.post_to_slack_notify")
self.slack_post = patch("utils.post_to_slack_notify")
self.dib_update = patch(
"status_change.data_ingest_board_manager.DataIngestBoardManager.update"
)
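
Patching `utils.post_to_slack_notify` (rather than a name inside `slack_manager`) is what the deferred import requires: `SlackManager.update` now resolves the helper from `utils` at call time, after the patch is already active.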