
Commit f739160

Authored by Muhammed Hasan Celik and Avantika Lal
decima version=0.2.2, gene_expression prediction, sequence shifting and fine-tuning (#19)
* ensemble vep init
* backward compatibility of grelu, ensembling, testcases, custom fasta
* gene dataset
* gene expression prediction and sequence shifting
* fix testcase
* fix conflict
* Changes related to data processing and fine-tuning new models (#16)
  * enable finetune via cli
  * split input and output directories
  * add mygene
  * added ensembl
  * added N padding
  * add more params
  * added args to cli finetune
  * add csv logging
  * add run name to checkpoints
  * gene pearson metric
  * training 202506
  * added new params
  * added topk
  * reset unnecessary changes
  * fixed save_top_k typo
  * more useful print
  * finetuning updates

  Co-authored-by: Muhammed Hasan Celik <celik.muhammed_hasan@gene.com>

* fix testcases
* branch review updates

Co-authored-by: Muhammed Hasan Celik <celik.muhammed_hasan@gene.com>
Co-authored-by: Avantika Lal <avantikalal1990@gmail.com>
1 parent: 3dc14d7 · commit: f739160

File tree

15 files changed: +520, −159 lines

src/decima/cli/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,8 +6,8 @@
 from decima.cli.attributions import cli_attributions
 from decima.cli.query_cell import cli_query_cell
 from decima.cli.vep import cli_predict_variant_effect
+from decima.cli.finetune import cli_finetune
 from decima.cli.vep import cli_vep_ensemble
-# from decima.cli.finetune import cli_finetune


 logger = logging.getLogger("decima")

@@ -33,6 +33,7 @@ def main():
 main.add_command(cli_attributions, name="attributions")
 main.add_command(cli_query_cell, name="query-cell")
 main.add_command(cli_predict_variant_effect, name="vep")
+main.add_command(cli_finetune, name="finetune")
 main.add_command(cli_vep_ensemble, name="vep-ensemble")
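For context, this hunk follows click's standard group-registration pattern: a function decorated with `@click.command()` becomes a subcommand once passed to `Group.add_command`. A minimal, self-contained sketch of that pattern (not Decima's actual module layout):

import click


@click.group()
def main():
    """Top-level CLI group (stand-in for Decima's entry point)."""


@click.command()
def cli_finetune():
    """Placeholder body; the real command lives in decima.cli.finetune."""
    click.echo("finetune called")


# Registering under an explicit name exposes it as `<prog> finetune`.
main.add_command(cli_finetune, name="finetune")

if __name__ == "__main__":
    main()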

src/decima/cli/finetune.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Finetune the Decima model."""
2-
3-
import os
2+
import logging
43
import click
54
import anndata
65
import wandb
@@ -9,62 +8,74 @@
98

109

1110
@click.command()
12-
@click.option("--name", required=True, help="Project name")
13-
@click.option("--dir", required=True, help="Data directory path")
14-
@click.option("--lr", default=0.001, type=float, help="Learning rate")
15-
@click.option("--weight", required=True, type=float, help="Weight parameter")
16-
@click.option("--grad", required=True, type=int, help="Gradient accumulation steps")
17-
@click.option("--replicate", default=0, type=int, help="Replication number")
18-
@click.option("--bs", default=4, type=int, help="Batch size")
19-
def cli_finetune(name, dir, lr, weight, grad, replicate, bs):
11+
@click.option("--name", required=True, help="Name of the run.")
12+
@click.option("--model", default="0", type=str, help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.")
13+
@click.option("--matrix-file", required=True, help="Matrix file path.")
14+
@click.option("--h5-file", required=True, help="H5 file path.")
15+
@click.option("--outdir", required=True, help="Output directory path to save model checkpoints.")
16+
@click.option("--learning-rate", default=0.001, type=float, help="Learning rate.")
17+
@click.option("--loss-total-weight", required=True, type=float, help="Total weight parameter for the loss function.")
18+
@click.option("--gradient-accumulation", required=True, type=int, help="Gradient accumulation steps.")
19+
@click.option("--batch-size", default=4, type=int, help="Batch size.")
20+
@click.option("--max-seq-shift", default=5000, type=int, help="Shift augmentation.")
21+
@click.option("--gradient-clipping", default=0.0, type=float, help="Gradient clipping.")
22+
@click.option("--save-top-k", default=1, type=int, help="Number of checkpoints to save.")
23+
@click.option("--epochs", default=1, type=int, help="Number of epochs.")
24+
@click.option("--logger", default="wandb", type=str, help="Logger.")
25+
@click.option("--num-workers", default=16, type=int, help="Number of workers.")
26+
@click.option("--seed", default=0, type=int, help="Random seed.")
27+
def cli_finetune(name, model, matrix_file, h5_file , outdir, learning_rate, loss_total_weight, gradient_accumulation, batch_size, max_seq_shift, gradient_clipping, save_top_k, epochs, logger, num_workers, seed):
2028
"""Finetune the Decima model."""
21-
wandb.login(host="https://genentech.wandb.io")
22-
run = wandb.init(project="decima", dir=name, name=name)
23-
24-
matrix_file = os.path.join(dir, "aggregated.h5ad")
25-
h5_file = os.path.join(dir, "data.h5")
26-
print(f"Data paths: {matrix_file}, {h5_file}")
27-
28-
print("Reading anndata")
29+
train_logger = logger
30+
logger = logging.getLogger("decima")
31+
logger.info(f"Data paths: matrix_file={matrix_file}, h5_file={h5_file}")
32+
logger.info("Reading anndata")
2933
ad = anndata.read_h5ad(matrix_file)
3034

31-
print("Making dataset objects")
35+
logger.info("Making dataset objects")
3236
train_dataset = HDF5Dataset(
3337
h5_file=h5_file,
3438
ad=ad,
3539
key="train",
36-
max_seq_shift=5000,
40+
max_seq_shift=max_seq_shift,
3741
augment_mode="random",
38-
seed=0,
42+
seed=seed,
3943
)
4044
val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0)
4145

4246
train_params = {
43-
"optimizer": "adam",
44-
"batch_size": bs,
45-
"num_workers": 16,
47+
"name": name,
48+
"batch_size": batch_size,
49+
"num_workers": num_workers,
4650
"devices": 0,
47-
"logger": "wandb",
48-
"save_dir": dir,
49-
"max_epochs": 15,
50-
"lr": lr,
51-
"total_weight": weight,
52-
"accumulate_grad_batches": grad,
51+
"logger": train_logger,
52+
"save_dir": outdir,
53+
"max_epochs": epochs,
54+
"lr": learning_rate,
55+
"total_weight": loss_total_weight,
56+
"accumulate_grad_batches": gradient_accumulation,
5357
"loss": "poisson_multinomial",
54-
"pairs": ad.uns["disease_pairs"].values,
58+
# "pairs": ad.uns["disease_pairs"].values,
59+
"clip": gradient_clipping,
60+
"save_top_k": save_top_k,
61+
"pin_memory": True,
5562
}
5663
model_params = {
5764
"n_tasks": ad.shape[0],
58-
"replicate": replicate,
65+
"replicate": model,
5966
}
60-
print(f"train_params: {train_params}")
61-
print(f"model_params: {model_params}")
67+
logger.info(f"train_params: {train_params}")
68+
logger.info(f"model_params: {model_params}")
6269

63-
print("Initializing model")
70+
logger.info("Initializing model")
6471
model = LightningModel(model_params=model_params, train_params=train_params)
6572

66-
print("Training")
73+
logger.info("Training")
74+
if logger == "wandb":
75+
wandb.login(host="https://genentech.wandb.io")
76+
run = wandb.init(project="decima", dir=name, name=name)
6777
model.train_on_dataset(train_dataset, val_dataset)
6878
train_dataset.close()
6979
val_dataset.close()
70-
run.finish()
80+
if logger == "wandb":
81+
run.finish()
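For a quick smoke test of the new flags without a shell, click's test runner can drive the command in-process. A minimal sketch; the file paths and numeric values are placeholders, and `--logger csv` is chosen here only to skip the wandb login branch:

from click.testing import CliRunner

from decima.cli.finetune import cli_finetune

runner = CliRunner()
result = runner.invoke(
    cli_finetune,
    [
        "--name", "finetune-demo",
        "--matrix-file", "aggregated.h5ad",  # placeholder path
        "--h5-file", "data.h5",              # placeholder path
        "--outdir", "checkpoints/",
        "--loss-total-weight", "0.2",        # required option; value illustrative
        "--gradient-accumulation", "4",      # required option; value illustrative
        "--logger", "csv",                   # avoid the wandb login branch
    ],
)
print(result.exit_code, result.output)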

src/decima/cli/predict_genes.py

Lines changed: 57 additions & 63 deletions
@@ -1,73 +1,67 @@
-"""Make predictions for all genes using an HDF5 file created by Decima's ``write_hdf5.py``."""
-
-import os
 import click
-import anndata
-import numpy as np
-import torch
-from decima.constants import DECIMA_CONTEXT_SIZE
-from decima.model.lightning import LightningModel
-from decima.data.read_hdf5 import list_genes
-from decima.data.dataset import HDF5Dataset
-
-# TODO: input can be just a h5ad file rather than a combination of h5 and matrix file.
+from decima.tools.inference import predict_gene_expression


 @click.command()
-@click.option("--device", type=int, help="Which GPU to use.")
-@click.option("--ckpts", multiple=True, required=True, help="Path to the model checkpoint(s).")
-@click.option("--h5_file", required=True, help="Path to h5 file indexed by genes.")
-@click.option("--matrix_file", required=True, help="Path to h5ad file containing genes to predict.")
-@click.option("--out_file", required=True, help="Output file path.")
+@click.option("-o", "--output", type=click.Path(), help="Path to the output h5ad file.")
+@click.option(
+    "--genes",
+    type=str,
+    default=None,
+    help="Comma-separated list of genes to predict. Default: None (all genes). If provided, only these genes are predicted.",
+)
+@click.option(
+    "-m",
+    "--model",
+    type=str,
+    default="ensemble",
+    help="Model to use: `0`, `1`, `2`, `3`, `ensemble`, or `path/to/model.ckpt`.",
+)
+@click.option(
+    "--metadata",
+    type=click.Path(exists=True),
+    default=None,
+    help="Path to the metadata anndata file. Default: None.",
+)
+@click.option(
+    "--device",
+    type=str,
+    default=None,
+    help="Device to use. Default: None, which automatically selects the best device.",
+)
+@click.option("--batch-size", type=int, default=8, help="Batch size for the model. Default: 8.")
+@click.option("--num-workers", type=int, default=4, help="Number of workers for the loader. Default: 4.")
 @click.option("--max_seq_shift", default=0, help="Maximum jitter for augmentation.")
-def cli_predict_genes(device, ckpts, h5_file, matrix_file, out_file, max_seq_shift):
-    """Make predictions for all genes."""
-    torch.set_float32_matmul_precision("medium")
-
-    # TODO: device is unused, set the device appropriately
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
-    device = torch.device(0)
+@click.option("--genome", type=str, default="hg38", help="Genome build. Default: hg38.")
+@click.option(
+    "--save-replicates",
+    is_flag=True,
+    help="Save the per-replicate predictions in the output file. Default: False.",
+)
+def cli_predict_genes(
+    output, genes, model, metadata, device, batch_size, num_workers, max_seq_shift, genome, save_replicates
+):
+    if model in ["0", "1", "2", "3"]:
+        model = int(model)

-    print("Loading anndata")
-    ad = anndata.read_h5ad(matrix_file)
-    assert np.all(list_genes(h5_file, key=None) == ad.var_names.tolist())
-
-    print("Making dataset")
-    ds = HDF5Dataset(
-        key=None,
-        h5_file=h5_file,
-        ad=ad,
-        seq_len=DECIMA_CONTEXT_SIZE,
-        max_seq_shift=max_seq_shift,
-    )
+    if isinstance(device, str) and device.isdigit():
+        device = int(device)

-    print("Loading models from checkpoint")
-    models = [LightningModel.load_from_checkpoint(f).eval() for f in ckpts]
+    if genes is not None:
+        genes = genes.split(",")

-    print("Computing predictions")
-    preds = (
-        np.stack([model.predict_on_dataset(ds, devices=0, batch_size=6, num_workers=16) for model in models]).mean(0).T
-    )
-    ad.layers["preds"] = preds
+    if save_replicates and (model != "ensemble"):
+        raise ValueError("`--save-replicates` is only supported for the ensemble model (`--model ensemble`).")

-    print("Computing correlations per gene")
-    ad.var["pearson"] = [np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1])]
-    ad.var["size_factor_pearson"] = [np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1])]
-    print(
-        f"Mean Pearson Correlation per gene: True: {ad.var.pearson.mean().round(2)} Size Factor: {ad.var.size_factor_pearson.mean().round(2)}"
+    ad = predict_gene_expression(
+        genes=genes,
+        model=model,
+        metadata_anndata=metadata,
+        device=device,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_seq_shift=max_seq_shift,
+        genome=genome,
+        save_replicates=save_replicates,
     )
-
-    print("Computing correlation per track")
-    for dataset in ad.var.dataset.unique():
-        key = f"{dataset}_pearson"
-        ad.obs[key] = [
-            np.corrcoef(
-                ad[i, ad.var.dataset == dataset].X,
-                ad[i, ad.var.dataset == dataset].layers["preds"],
-            )[0, 1]
-            for i in range(ad.shape[0])
-        ]
-        print(f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {ad.obs[key].mean().round(2)}")
-
-    print("Saved")
-    ad.write_h5ad(out_file)
+    ad.write_h5ad(output)
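With this rewrite the CLI is a thin wrapper over `decima.tools.inference.predict_gene_expression`, so the same prediction can be scripted directly. A minimal sketch using only the keyword arguments visible in the wrapper above; the gene names and output path are illustrative:

from decima.tools.inference import predict_gene_expression

# Ensemble prediction for two example genes, keeping per-replicate outputs.
ad = predict_gene_expression(
    genes=["GATA1", "TP53"],  # None would predict all genes
    model="ensemble",         # or 0-3 for a single replicate, or a .ckpt path
    metadata_anndata=None,    # None falls back to the default metadata
    device=None,              # None auto-selects the device
    batch_size=8,
    num_workers=4,
    max_seq_shift=0,
    genome="hg38",
    save_replicates=True,     # only valid with model="ensemble"
)
ad.write_h5ad("predictions.h5ad")  # illustrative output path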

src/decima/core/result.py

Lines changed: 21 additions & 3 deletions
@@ -157,7 +157,25 @@ def predicted_expression_matrix(self, genes: Optional[List[str]] = None) -> pd.DataFrame:
         else:
             return pd.DataFrame(self.anndata[:, genes].layers["preds"], index=self.cells, columns=genes)

-    def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
+    def _pad_gene_metadata(self, gene_meta: pd.Series, padding: int = 0) -> pd.Series:
+        """
+        Pad gene metadata with padding.
+
+        Args:
+            gene_meta: Gene metadata
+            padding: Padding to add to the gene metadata
+
+        Returns:
+            pd.Series: Padded gene metadata
+        """
+        gene_meta = gene_meta.copy()
+        gene_meta["start"] = gene_meta["start"] - padding
+        gene_meta["end"] = gene_meta["end"] + padding
+        gene_meta["gene_mask_start"] = gene_meta["gene_mask_start"] + padding
+        gene_meta["gene_mask_end"] = gene_meta["gene_mask_end"] + padding
+        return gene_meta
+
+    def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padding: int = 0) -> torch.Tensor:
         """Prepare one-hot encoding for a gene.

         Args:
@@ -167,15 +185,15 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None) -> torch.Tensor:
             torch.Tensor: One-hot encoding of the gene
         """
         assert gene in self.genes, f"{gene} is not in the anndata object"
-        gene_meta = self.gene_metadata.loc[gene]
+        gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding)

         if variants is None:
             seq = intervals_to_strings(gene_meta, genome="hg38")
             gene_start, gene_end = gene_meta.gene_mask_start, gene_meta.gene_mask_end
         else:
             seq, (gene_start, gene_end) = prepare_seq_alt_allele(gene_meta, variants)

-        mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE))
+        mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2))
         mask[0, gene_start:gene_end] += 1
         mask = torch.from_numpy(mask).float()