3 changes: 2 additions & 1 deletion src/decima/cli/__init__.py
@@ -6,8 +6,8 @@
from decima.cli.attributions import cli_attributions
from decima.cli.query_cell import cli_query_cell
from decima.cli.vep import cli_predict_variant_effect
from decima.cli.finetune import cli_finetune
from decima.cli.vep import cli_vep_ensemble
# from decima.cli.finetune import cli_finetune


logger = logging.getLogger("decima")
@@ -33,6 +33,7 @@ def main():
main.add_command(cli_attributions, name="attributions")
main.add_command(cli_query_cell, name="query-cell")
main.add_command(cli_predict_variant_effect, name="vep")
main.add_command(cli_finetune, name="finetune")
main.add_command(cli_vep_ensemble, name="vep-ensemble")


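As a quick illustration (not part of the change), a smoke test that the new subcommand is registered on the CLI group could use click's CliRunner; `main` is the group defined in this file, and the asserted text mirrors the command's docstring:

# Illustrative sketch: confirm the new "finetune" subcommand is wired into
# the top-level CLI group. Uses click's CliRunner.
from click.testing import CliRunner
from decima.cli import main

result = CliRunner().invoke(main, ["finetune", "--help"])
assert result.exit_code == 0
assert "Finetune the Decima model" in result.output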
85 changes: 48 additions & 37 deletions src/decima/cli/finetune.py
@@ -1,6 +1,5 @@
"""Finetune the Decima model."""

import os
import logging
import click
import anndata
import wandb
@@ -9,62 +8,74 @@


@click.command()
@click.option("--name", required=True, help="Project name")
@click.option("--dir", required=True, help="Data directory path")
@click.option("--lr", default=0.001, type=float, help="Learning rate")
@click.option("--weight", required=True, type=float, help="Weight parameter")
@click.option("--grad", required=True, type=int, help="Gradient accumulation steps")
@click.option("--replicate", default=0, type=int, help="Replication number")
@click.option("--bs", default=4, type=int, help="Batch size")
def cli_finetune(name, dir, lr, weight, grad, replicate, bs):
@click.option("--name", required=True, help="Name of the run.")
@click.option("--model", default="0", type=str, help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.")
@click.option("--matrix-file", required=True, help="Matrix file path.")
@click.option("--h5-file", required=True, help="H5 file path.")
@click.option("--outdir", required=True, help="Output directory path to save model checkpoints.")
@click.option("--learning-rate", default=0.001, type=float, help="Learning rate.")
@click.option("--loss-total-weight", required=True, type=float, help="Total weight parameter for the loss function.")
@click.option("--gradient-accumulation", required=True, type=int, help="Gradient accumulation steps.")
@click.option("--batch-size", default=4, type=int, help="Batch size.")
@click.option("--max-seq-shift", default=5000, type=int, help="Shift augmentation.")
@click.option("--gradient-clipping", default=0.0, type=float, help="Gradient clipping.")
@click.option("--save-top-k", default=1, type=int, help="Number of checkpoints to save.")
@click.option("--epochs", default=1, type=int, help="Number of epochs.")
@click.option("--logger", default="wandb", type=str, help="Logger.")
@click.option("--num-workers", default=16, type=int, help="Number of workers.")
@click.option("--seed", default=0, type=int, help="Random seed.")
def cli_finetune(name, model, matrix_file, h5_file, outdir, learning_rate, loss_total_weight, gradient_accumulation, batch_size, max_seq_shift, gradient_clipping, save_top_k, epochs, logger, num_workers, seed):
"""Finetune the Decima model."""
wandb.login(host="https://genentech.wandb.io")
run = wandb.init(project="decima", dir=name, name=name)

matrix_file = os.path.join(dir, "aggregated.h5ad")
h5_file = os.path.join(dir, "data.h5")
print(f"Data paths: {matrix_file}, {h5_file}")

print("Reading anndata")
train_logger = logger
logger = logging.getLogger("decima")
logger.info(f"Data paths: matrix_file={matrix_file}, h5_file={h5_file}")
logger.info("Reading anndata")
ad = anndata.read_h5ad(matrix_file)

print("Making dataset objects")
logger.info("Making dataset objects")
train_dataset = HDF5Dataset(
h5_file=h5_file,
ad=ad,
key="train",
max_seq_shift=5000,
max_seq_shift=max_seq_shift,
augment_mode="random",
seed=0,
seed=seed,
)
val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0)

train_params = {
"optimizer": "adam",
"batch_size": bs,
"num_workers": 16,
"name": name,
"batch_size": batch_size,
"num_workers": num_workers,
"devices": 0,
"logger": "wandb",
"save_dir": dir,
"max_epochs": 15,
"lr": lr,
"total_weight": weight,
"accumulate_grad_batches": grad,
"logger": train_logger,
"save_dir": outdir,
"max_epochs": epochs,
"lr": learning_rate,
"total_weight": loss_total_weight,
"accumulate_grad_batches": gradient_accumulation,
"loss": "poisson_multinomial",
"pairs": ad.uns["disease_pairs"].values,
# "pairs": ad.uns["disease_pairs"].values,
"clip": gradient_clipping,
"save_top_k": save_top_k,
"pin_memory": True,
}
model_params = {
"n_tasks": ad.shape[0],
"replicate": replicate,
"replicate": model,
}
print(f"train_params: {train_params}")
print(f"model_params: {model_params}")
logger.info(f"train_params: {train_params}")
logger.info(f"model_params: {model_params}")

print("Initializing model")
logger.info("Initializing model")
model = LightningModel(model_params=model_params, train_params=train_params)

print("Training")
logger.info("Training")
if logger == "wandb":
wandb.login(host="https://genentech.wandb.io")
run = wandb.init(project="decima", dir=name, name=name)
model.train_on_dataset(train_dataset, val_dataset)
train_dataset.close()
val_dataset.close()
run.finish()
if logger == "wandb":
run.finish()
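For illustration, the reworked command could be exercised from Python via click's test runner; every path and hyperparameter value below is a hypothetical placeholder, and only the required options plus a couple of defaults are shown:

# Illustrative invocation of the updated finetune command (placeholder values).
from click.testing import CliRunner
from decima.cli.finetune import cli_finetune

args = [
    "--name", "finetune-demo",
    "--matrix-file", "aggregated.h5ad",
    "--h5-file", "data.h5",
    "--outdir", "checkpoints/",
    "--loss-total-weight", "0.1",
    "--gradient-accumulation", "4",
    "--learning-rate", "0.0005",
    "--epochs", "1",
]
result = CliRunner().invoke(cli_finetune, args)
print(result.output)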
120 changes: 57 additions & 63 deletions src/decima/cli/predict_genes.py
@@ -1,73 +1,67 @@
"""Make predictions for all genes using an HDF5 file created by Decima's ``write_hdf5.py``."""

import os
import click
import anndata
import numpy as np
import torch
from decima.constants import DECIMA_CONTEXT_SIZE
from decima.model.lightning import LightningModel
from decima.data.read_hdf5 import list_genes
from decima.data.dataset import HDF5Dataset

# TODO: input can be just a h5ad file rather than a combination of h5 and matrix file.
from decima.tools.inference import predict_gene_expression


@click.command()
@click.option("--device", type=int, help="Which GPU to use.")
@click.option("--ckpts", multiple=True, required=True, help="Path to the model checkpoint(s).")
@click.option("--h5_file", required=True, help="Path to h5 file indexed by genes.")
@click.option("--matrix_file", required=True, help="Path to h5ad file containing genes to predict.")
@click.option("--out_file", required=True, help="Output file path.")
@click.option("-o", "--output", type=click.Path(), help="Path to the output h5ad file.")
@click.option(
"--genes",
type=str,
default=None,
help="List of genes to predict. Default: None (all genes). If provided, only these genes will be predicted.",
)
@click.option(
"-m",
"--model",
type=str,
default="ensemble",
help="Path to the model checkpoint: `0`, `1`, `2`, `3`, `ensemble` or `path/to/model.ckpt`.",
)
@click.option(
"--metadata",
type=click.Path(exists=True),
default=None,
help="Path to the metadata anndata file. Default: None.",
)
@click.option(
"--device",
type=str,
default=None,
help="Device to use. Default: None which automatically selects the best device.",
)
@click.option("--batch-size", type=int, default=8, help="Batch size for the model. Default: 8")
@click.option("--num-workers", type=int, default=4, help="Number of workers for the loader. Default: 4")
@click.option("--max_seq_shift", default=0, help="Maximum jitter for augmentation.")
def cli_predict_genes(device, ckpts, h5_file, matrix_file, out_file, max_seq_shift):
"""Make predictions for all genes."""
torch.set_float32_matmul_precision("medium")

# TODO: device is unused, set the device appropriately
os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
device = torch.device(0)
@click.option("--genome", type=str, default="hg38", help="Genome build. Default: hg38.")
@click.option(
"--save-replicates",
is_flag=True,
help="Save the replicates in the output parquet file. Default: False.",
)
def cli_predict_genes(
output, genes, model, metadata, device, batch_size, num_workers, max_seq_shift, genome, save_replicates
):
if model in ["0", "1", "2", "3"]:
model = int(model)

print("Loading anndata")
ad = anndata.read_h5ad(matrix_file)
assert np.all(list_genes(h5_file, key=None) == ad.var_names.tolist())

print("Making dataset")
ds = HDF5Dataset(
key=None,
h5_file=h5_file,
ad=ad,
seq_len=DECIMA_CONTEXT_SIZE,
max_seq_shift=max_seq_shift,
)
if isinstance(device, str) and device.isdigit():
device = int(device)

print("Loading models from checkpoint")
models = [LightningModel.load_from_checkpoint(f).eval() for f in ckpts]
if genes is not None:
genes = genes.split(",")

print("Computing predictions")
preds = (
np.stack([model.predict_on_dataset(ds, devices=0, batch_size=6, num_workers=16) for model in models]).mean(0).T
)
ad.layers["preds"] = preds
if save_replicates and (model != "ensemble"):
raise ValueError("`--save-replicates` is only supported for ensemble model (`--model ensemble`).")

print("Computing correlations per gene")
ad.var["pearson"] = [np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1])]
ad.var["size_factor_pearson"] = [np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1])]
print(
f"Mean Pearson Correlation per gene: True: {ad.var.pearson.mean().round(2)} Size Factor: {ad.var.size_factor_pearson.mean().round(2)}"
ad = predict_gene_expression(
genes=genes,
model=model,
metadata_anndata=metadata,
device=device,
batch_size=batch_size,
num_workers=num_workers,
max_seq_shift=max_seq_shift,
genome=genome,
save_replicates=save_replicates,
)

print("Computing correlation per track")
for dataset in ad.var.dataset.unique():
key = f"{dataset}_pearson"
ad.obs[key] = [
np.corrcoef(
ad[i, ad.var.dataset == dataset].X,
ad[i, ad.var.dataset == dataset].layers["preds"],
)[0, 1]
for i in range(ad.shape[0])
]
print(f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {ad.obs[key].mean().round(2)}")

print("Saved")
ad.write_h5ad(out_file)
ad.write_h5ad(output)
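The command is now a thin wrapper around predict_gene_expression from decima.tools.inference; an equivalent direct call might look like this (gene names and the output path are hypothetical placeholders, and the defaults mirror the CLI options above):

# Illustrative direct use of the helper the CLI now delegates to.
from decima.tools.inference import predict_gene_expression

ad = predict_gene_expression(
    genes=["GENE_A", "GENE_B"],  # or None to predict all genes
    model="ensemble",            # a replicate number or a checkpoint path also works
    metadata_anndata=None,
    device=None,                 # None auto-selects the best device
    batch_size=8,
    num_workers=4,
    max_seq_shift=0,
    genome="hg38",
    save_replicates=False,
)
ad.write_h5ad("predictions.h5ad")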
26 changes: 23 additions & 3 deletions src/decima/core/result.py
@@ -157,7 +157,27 @@ def predicted_expression_matrix(self, genes: Optional[List[str]] = None) -> pd.DataFrame:
else:
return pd.DataFrame(self.anndata[:, genes].layers["preds"], index=self.cells, columns=genes)

def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
def _pad_gene_metadata(self, gene_meta: pd.Series, padding: int = 0) -> pd.Series:
"""
Adjust gene metadata coordinates for the requested padding.

Args:
gene_meta: Gene metadata
padding: Number of positions of padding to apply to the gene coordinates

Returns:
pd.Series: Padded gene metadata
"""
gene_meta = gene_meta.copy()
gene_meta["start"] = gene_meta["start"] - padding
gene_meta["end"] = gene_meta["end"] + padding
gene_meta["gene_start"] = gene_meta["gene_start"] - padding
gene_meta["gene_end"] = gene_meta["gene_end"] + padding
gene_meta["gene_mask_start"] = gene_meta["gene_mask_start"] - padding
gene_meta["gene_mask_end"] = gene_meta["gene_mask_end"] - padding
return gene_meta

def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padding: int = 0) -> torch.Tensor:
"""Prepare one-hot encoding for a gene.

Args:
@@ -167,15 +187,15 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
torch.Tensor: One-hot encoding of the gene
"""
assert gene in self.genes, f"{gene} is not in the anndata object"
gene_meta = self.gene_metadata.loc[gene]
gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding)

if variants is None:
seq = intervals_to_strings(gene_meta, genome="hg38")
gene_start, gene_end = gene_meta.gene_mask_start, gene_meta.gene_mask_end
else:
seq, (gene_start, gene_end) = prepare_seq_alt_allele(gene_meta, variants)

mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE))
mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2))
mask[0, gene_start:gene_end] += 1
mask = torch.from_numpy(mask).float()
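
For illustration, the new padding argument could be exercised like this; `result` stands in for an already-constructed instance of the class defined in src/decima/core/result.py, and the gene name and padding value are hypothetical placeholders:

# Illustrative sketch of prepare_one_hot with the new padding argument.
# "MYC" and 1000 are placeholder values.
one_hot = result.prepare_one_hot("MYC", variants=None, padding=1000)
# Per this diff, the underlying gene mask now spans DECIMA_CONTEXT_SIZE + 2 * 1000 positions.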
