Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
---
environment_specs:
- name: Test
vars:
sumstats_path: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/sumstats_subset
ldscore_base_path: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/inputs/gnomad/ldscores/full
study_index_path: gs://open-targets-data-releases/25.12/output/study
outputs_dir: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/output
gentropy_image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/opentargets/gentropy:3.3.0-dev.2

- name: Prod
vars:
sumstats_path: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/sumstats_subset
ldscore_base_path: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/inputs/gnomad/ldscores/full
study_index_path: gs://open-targets-data-releases/25.12/output/study
outputs_dir: gs://genetics-portal-dev-analysis/xg1/LDSC_pyspark/output
gentropy_image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/opentargets/gentropy:3.3.0-dev.2

env: Test

nodes:
- id: heritability_estimate
kind: Task
prerequisites: []
cluster: false
google_batch_index_specs:
manifest_generator_label: heritability-estimate
max_task_count: 20
manifest_generator_specs:
commands:
- -c
- gentropy
options:
step: heritability_estimate
step.session.write_mode: overwrite
step.session.output_partitions: 1
step.summary_statistics_input_path: "$INPUT_PARTITION"
step.study_index_input_path: "{study_index_path}"
step.ldscore_base_path: "{ldscore_base_path}"
step.heritability_output_path: "$OUTPUT_PARTITION"
step.session.extended_spark_conf: '{spark.jars:https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,spark.hadoop.fs.gs.requester.pays.mode:"CUSTOM",spark.hadoop.fs.gs.requester.pays.buckets:"open-targets-data-releases",spark.hadoop.fs.gs.requester.pays.project.id:"open-targets-genetics-dev"}'

Check failure on line 41 in src/orchestration/dags/config/gwas_catalog_heritability_estimate.yaml

View workflow job for this annotation

GitHub Actions / test

41:201 [line-length] line too long (330 > 200 characters)
manifest_kwargs:
input_glob: "{sumstats_path}"
output_prefix: "{outputs_dir}/heritability_estimates"

google_batch:
entrypoint: /usr/bin/bash
image: "{gentropy_image}"
resource_specs:
cpu_milli: 2000
memory_mib: 7000
boot_disk_mib: 4000
task_specs:
max_retry_count: 2
max_run_duration: "1h"
policy_specs:
machine_type: n1-standard-2
59 changes: 59 additions & 0 deletions src/orchestration/dags/gwas_catalog_heritability_estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Any

from airflow import DAG
from airflow.models.baseoperator import chain

from orchestration.operators.batch.generic import BatchIndexOperator, BatchJobOperator
from orchestration.utils import read_yaml_config, resource_name


def format_config(obj: Any, vars_dict: dict[str, str]) -> Any:
    """Recursively interpolate ``{placeholder}`` variables in a config tree.

    Strings are expanded with ``str.format(**vars_dict)``; lists and dicts
    are walked recursively; every other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: format_config(value, vars_dict) for key, value in obj.items()}
    if isinstance(obj, list):
        return [format_config(item, vars_dict) for item in obj]
    if isinstance(obj, str):
        return obj.format(**vars_dict)
    return obj


default_args = {
    "owner": "opentargets",
}

with DAG(
    dag_id="gentropy_heritability_estimate",
    # Fixed garbled description: "heritabilbatch_jobsity" was a paste error.
    description="Run heritability estimation for harmonised summary statistics using gentropy",
    schedule_interval=None,
    start_date=datetime(2023, 1, 1),
    catchup=False,
    default_args=default_args,
    tags=["gentropy", "heritability"],
) as dag:
    # Load the YAML config holding per-environment variables and the Google
    # Batch node specification for this DAG.
    config = read_yaml_config(Path(__file__).parent / "config" / "gwas_catalog_heritability_estimate.yaml")

    # Pick the active environment block and its variable substitutions.
    env_name = config["env"]
    env_spec = next(env for env in config["environment_specs"] if env["name"] == env_name)
    env_vars = env_spec["vars"]

    # Resolve this step's node and interpolate {placeholders} with env vars.
    step_name = "heritability_estimate"
    node = next(node for node in config["nodes"] if node["id"] == step_name)
    step_config = format_config(node, env_vars)

    # Build the batch index: one row per input partition to process.
    batch_index = BatchIndexOperator(
        task_id=f"{step_name}.batch_index",
        batch_index_specs=step_config["google_batch_index_specs"],
    )

    # Dynamic task mapping: fan out one Google Batch job per index row.
    batch_jobs = BatchJobOperator.partial(
        task_id=f"{step_name}.batch_job",
        job_name=f"up-{step_name.replace('_', '-')}",
        google_batch=step_config["google_batch"],
    ).expand(batch_index_row=batch_index.output)

    chain(batch_index, batch_jobs)
35 changes: 27 additions & 8 deletions src/orchestration/operators/batch/generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
"""Batch Job main operators."""
"""Batch Job main operators.

This module defines two generic Airflow operators used to submit
partitioned jobs to Google Batch. The ``BatchIndexOperator`` builds
a manifest describing each task and writes the associated environment
variables, while the ``BatchJobOperator`` wraps the Google Batch
client to submit each individual task.

This file is a local snapshot of the corresponding module in the
OpenTargets ``orchestration`` repository. It has been extended to
register an additional manifest generator for heritability estimation.
"""

from __future__ import annotations

Expand All @@ -16,18 +27,25 @@
from orchestration.utils.common import GCP_PROJECT_GENETICS, GCP_REGION
from orchestration.utils.labels import Labels

try:
from orchestration.operators.batch.manifest_generators.heritability_estimate import HeritabilityManifestGenerator
except Exception:
HeritabilityManifestGenerator = None # type: ignore


class BatchIndexOperator(BaseOperator):
"""Operator to prepare google batch job index and partition it into the manifests.

Each manifest prepared by the operator should create an environment for a single batch job.
Each row of the individual manifest should represent individual batch task.
Each manifest prepared by the operator should create an environment for
a single batch job. Each row of the individual manifest should represent
an individual batch task.
"""

# NOTE: here register all manifest generators.
manifest_generator_registry: dict[str, type[ProtoManifestGenerator]] = {
"gwas_catalog_harmonisation": HarmonisationManifestGenerator,
"gentropy-step": GentropyStepGoogleBatchManifestGenerator,
"heritability-estimate": HeritabilityManifestGenerator,
}

def __init__(
Expand All @@ -43,7 +61,7 @@ def __init__(

@classmethod
def get_generator(cls, label: str) -> type[ProtoManifestGenerator]:
"""Get the generator by it's label in the registry."""
"""Get the generator by its label in the registry."""
try:
return cls.manifest_generator_registry[label]
except KeyError:
Expand All @@ -54,8 +72,9 @@ def execute(self, **kwargs) -> list[BatchIndexRow]:
generator = self.manifest_generator.from_generator_config(self.manifest_generator_specs)
index = generator.generate_batch_index()
if not self.max_task_count:
# if specified 0 or not specified in the config, then assume to use the number
# of tasks that is in the output of the BatchIndex.vars_list from manifest generation
# if specified 0 or not specified in the config, then assume to use
# the number of tasks that is in the output of the BatchIndex.vars_list
# from manifest generation
self.max_task_count = len(index.vars_list)
self.log.info(index)
partitioned_index = index.partition(self.max_task_count)
Expand All @@ -65,7 +84,7 @@ def execute(self, **kwargs) -> list[BatchIndexRow]:
class BatchJobOperator(CloudBatchSubmitJobOperator):
"""Generic Batch Job operator.

This operator has to be used in conjunction to the BatchIndexOperator.
This operator has to be used in conjunction with the BatchIndexOperator.
It runs the google batch jobs defined by the BatchIndexOperator.
"""

Expand All @@ -78,7 +97,7 @@ def __init__(
region: str = GCP_REGION,
labels: Labels | None = None,
**kwargs,
):
) -> None:
super().__init__(
project_id=project_id,
region=region,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Google Batch manifest generator for heritability estimation step.

This manifest generator is responsible for preparing the input/output
environment for running the heritability estimation step across many
harmonised summary statistics. It operates by scanning a glob of
harmonised summary statistics on Google Cloud Storage (GCS) and
constructing a per-file mapping into a corresponding heritability
estimate output path. If the heritability estimate for a given
summary statistic already exists, the file is skipped. The result is
used by the ``BatchIndexOperator`` to build individual Google Batch
tasks.

The generator mirrors the behaviour of
``GentropyStepGoogleBatchManifestGenerator`` used for the locus-to-gene
prediction step but adds an existence check on the output path. This
avoids re‑computing heritability estimates for studies which have
already been processed.
"""

from __future__ import annotations

import logging

from airflow.exceptions import AirflowSkipException
from airflow.providers.google.cloud.hooks.gcs import GCSHook

from orchestration.operators.batch.batch_index import BatchIndex
from orchestration.operators.batch.manifest_generators import ProtoManifestGenerator
from orchestration.types import ManifestGeneratorSpecs
from orchestration.utils.path import GCSPath


class HeritabilityManifestGenerator(ProtoManifestGenerator):
    """Manifest generator for heritability estimation.

    Parameters
    ----------
    commands : list[str]
        Shell command fragments used when invoking the gentropy CLI.
    options : dict[str, str]
        Hydra options controlling the gentropy step. These options are
        propagated unchanged into the batch job environment.
    manifest_kwargs : dict[str, str]
        A mapping containing the keys ``input_glob`` and ``output_prefix``.
        ``input_glob`` should be a GCS URI glob pointing at the
        harmonised summary statistics. ``output_prefix`` is the base
        directory under which heritability outputs will be written.
    gcp_conn_id : str, optional
        Airflow connection ID used for the GCS client. Defaults to
        ``"google_cloud_default"``.
    """

    def __init__(
        self,
        *,
        commands: list[str],
        options: dict[str, str],
        manifest_kwargs: dict[str, str],
        gcp_conn_id: str = "google_cloud_default",
    ) -> None:
        self.commands = commands
        self.options = options
        self.gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
        self.input_glob = GCSPath(manifest_kwargs.get("input_glob", ""))
        self.output_prefix = GCSPath(manifest_kwargs.get("output_prefix", ""))

    @classmethod
    def from_generator_config(cls, specs: ManifestGeneratorSpecs) -> "HeritabilityManifestGenerator":
        """Construct a generator from configuration specs.

        This method is invoked by the ``BatchIndexOperator`` when
        deserialising the manifest generator from the YAML config. It
        forwards the specification fields directly to the constructor.
        """
        return cls(
            commands=specs["commands"],
            options=specs["options"],
            manifest_kwargs=specs["manifest_kwargs"],
        )

    def generate_batch_index(self) -> BatchIndex:
        """Create the batch index used by Google Batch.

        The index comprises a list of dictionaries, each containing
        ``INPUT_PARTITION`` and ``OUTPUT_PARTITION`` environment
        variables. These entries feed into the batch job to drive
        partitioned processing of individual summary statistics files.
        """
        vars_list = self.build_vars_list()
        return BatchIndex(
            vars_list=vars_list,
            options=self.options,
            commands=self.commands,
        )

    def build_vars_list(self) -> list[dict[str, str]]:
        """Build one batch task per study directory, skipping existing outputs.

        Raises:
            ValueError: if ``input_glob`` is not a ``gs://bucket/prefix`` URI.
            AirflowSkipException: if no study directory remains to process.
        """
        logger = logging.getLogger(__name__)
        dataset_root = self.input_glob.gcs_path.rstrip("/")

        if not dataset_root.startswith("gs://"):
            raise ValueError(f"Expected gs:// path, got {dataset_root}")

        without_scheme = dataset_root[len("gs://") :]
        # Raise a clear error instead of an opaque unpacking failure when the
        # URI has no object prefix (e.g. "gs://bucket").
        if "/" not in without_scheme:
            raise ValueError(f"Expected gs://bucket/prefix, got {dataset_root}")
        bucket_name, root_prefix = without_scheme.split("/", 1)
        root_prefix = root_prefix.rstrip("/") + "/"

        blobs = self.gcs_hook.list(
            bucket_name=bucket_name,
            prefix=root_prefix,
        )

        # Collect the first path component below the root prefix; blobs are
        # expected to be laid out as <root_prefix>STUDY_ID/<file>.
        study_dirs: set[str] = set()
        for blob in blobs:
            rel = blob[len(root_prefix) :]
            if not rel:
                continue
            parts = rel.split("/")
            if len(parts) >= 2 and parts[0]:
                study_dirs.add(parts[0])

        vars_list: list[dict[str, str]] = []
        for study_dir in sorted(study_dirs):
            input_path = f"{dataset_root}/{study_dir}"
            output_path = f"{self.output_prefix.gcs_path.rstrip('/')}/{study_dir}"

            try:
                exists = GCSPath(output_path).exists()
            except Exception:
                # Best-effort existence check: assume the output is missing,
                # but surface the failure in the logs rather than silently
                # swallowing it (the original `except: exists = False` hid
                # permission and connectivity errors).
                logger.warning(
                    "Could not check existence of %s; assuming it is missing.",
                    output_path,
                    exc_info=True,
                )
                exists = False

            if exists:
                # Output already produced by a previous run; skip recomputing.
                continue

            vars_list.append({
                "INPUT_PARTITION": input_path,
                "OUTPUT_PARTITION": output_path,
            })

        # Use the logging framework instead of bare print() so these messages
        # reach the Airflow task log.
        logger.info("dataset_root=%s root_prefix=%s", dataset_root, root_prefix)
        logger.info(
            "n_blobs=%d n_study_dirs=%d n_vars_list=%d",
            len(blobs),
            len(study_dirs),
            len(vars_list),
        )
        logger.debug("study_dirs_sample=%s", sorted(study_dirs)[:10])

        if not vars_list:
            raise AirflowSkipException(f"No study directories found to process under {dataset_root}")

        return vars_list
8 changes: 3 additions & 5 deletions src/orchestration/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def create_task_spec(
time_duration = time_to_seconds(task_specs["max_run_duration"])
# See https://docs.cloud.google.com/batch/docs/troubleshooting#reserved-exit-codes
default_lifecycle_policies = [
(
LifecyclePolicy(
action=LifecyclePolicy.Action.RETRY_TASK,
action_condition=LifecyclePolicy.ActionCondition(exit_codes=[50001, 50002, 50003, 50004, 50005]),
),
LifecyclePolicy(
action=LifecyclePolicy.Action.RETRY_TASK,
action_condition=LifecyclePolicy.ActionCondition(exit_codes=[50001, 50002, 50003, 50004, 50005]),
)
]
parameters = {
Expand Down
Loading