-
Notifications
You must be signed in to change notification settings - Fork 8
decima version=0.2.2, gene_expression prediction, sequence shifting and fine-tuning #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
f371417
ensemble vep init
bda1cd0
backward compability of grelu, ensembling, testcases, custom fasta
6508480
gene dataset
329365c
gene expression prediction and sequence shifting
c7582ba
fix testcase
cf700ac
conflig
2542b28
merge conflict
b0f1c18
Changes related to data processing and fine-tuning new models (#16)
avantikalal d985f1c
fix testcases
83a12fa
branch review updates
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,73 +1,67 @@ | ||
| """Make predictions for all genes using an HDF5 file created by Decima's ``write_hdf5.py``.""" | ||
|
|
||
| import os | ||
| import click | ||
| import anndata | ||
| import numpy as np | ||
| import torch | ||
| from decima.constants import DECIMA_CONTEXT_SIZE | ||
| from decima.model.lightning import LightningModel | ||
| from decima.data.read_hdf5 import list_genes | ||
| from decima.data.dataset import HDF5Dataset | ||
|
|
||
| # TODO: input can be just a h5ad file rather than a combination of h5 and matrix file. | ||
| from decima.tools.inference import predict_gene_expression | ||
|
|
||
|
|
||
| @click.command() | ||
| @click.option("--device", type=int, help="Which GPU to use.") | ||
| @click.option("--ckpts", multiple=True, required=True, help="Path to the model checkpoint(s).") | ||
| @click.option("--h5_file", required=True, help="Path to h5 file indexed by genes.") | ||
| @click.option("--matrix_file", required=True, help="Path to h5ad file containing genes to predict.") | ||
| @click.option("--out_file", required=True, help="Output file path.") | ||
| @click.option("-o", "--output", type=click.Path(), help="Path to the output h5ad file.") | ||
| @click.option( | ||
| "--genes", | ||
| type=str, | ||
| default=None, | ||
| help="List of genes to predict. Default: None (all genes). If provided, only these genes will be predicted.", | ||
| ) | ||
| @click.option( | ||
| "-m", | ||
| "--model", | ||
| type=str, | ||
| default="ensemble", | ||
| help="Path to the model checkpoint: `0`, `1`, `2`, `3`, `ensemble` or `path/to/model.ckpt`.", | ||
| ) | ||
| @click.option( | ||
| "--metadata", | ||
| type=click.Path(exists=True), | ||
| default=None, | ||
| help="Path to the metadata anndata file. Default: None.", | ||
| ) | ||
| @click.option( | ||
| "--device", | ||
| type=str, | ||
| default=None, | ||
| help="Device to use. Default: None which automatically selects the best device.", | ||
| ) | ||
| @click.option("--batch-size", type=int, default=8, help="Batch size for the model. Default: 8") | ||
| @click.option("--num-workers", type=int, default=4, help="Number of workers for the loader. Default: 4") | ||
| @click.option("--max_seq_shift", default=0, help="Maximum jitter for augmentation.") | ||
| def cli_predict_genes(device, ckpts, h5_file, matrix_file, out_file, max_seq_shift): | ||
| """Make predictions for all genes.""" | ||
| torch.set_float32_matmul_precision("medium") | ||
|
|
||
| # TODO: device is unused, set the device appropriately | ||
| os.environ["CUDA_VISIBLE_DEVICES"] = str(device) | ||
| device = torch.device(0) | ||
| @click.option("--genome", type=str, default="hg38", help="Genome build. Default: hg38.") | ||
| @click.option( | ||
| "--save-replicates", | ||
| is_flag=True, | ||
| help="Save the replicates in the output parquet file. Default: False.", | ||
| ) | ||
| def cli_predict_genes( | ||
| output, genes, model, metadata, device, batch_size, num_workers, max_seq_shift, genome, save_replicates | ||
| ): | ||
| if model in ["0", "1", "2", "3"]: | ||
| model = int(model) | ||
|
|
||
| print("Loading anndata") | ||
| ad = anndata.read_h5ad(matrix_file) | ||
| assert np.all(list_genes(h5_file, key=None) == ad.var_names.tolist()) | ||
|
|
||
| print("Making dataset") | ||
| ds = HDF5Dataset( | ||
| key=None, | ||
| h5_file=h5_file, | ||
| ad=ad, | ||
| seq_len=DECIMA_CONTEXT_SIZE, | ||
| max_seq_shift=max_seq_shift, | ||
| ) | ||
| if isinstance(device, str) and device.isdigit(): | ||
| device = int(device) | ||
|
|
||
| print("Loading models from checkpoint") | ||
| models = [LightningModel.load_from_checkpoint(f).eval() for f in ckpts] | ||
| if genes is not None: | ||
| genes = genes.split(",") | ||
|
|
||
| print("Computing predictions") | ||
| preds = ( | ||
| np.stack([model.predict_on_dataset(ds, devices=0, batch_size=6, num_workers=16) for model in models]).mean(0).T | ||
| ) | ||
| ad.layers["preds"] = preds | ||
| if save_replicates and (model != "ensemble"): | ||
| raise ValueError("`--save-replicates` is only supported for ensemble model (`--model ensemble`).") | ||
|
|
||
| print("Computing correlations per gene") | ||
| ad.var["pearson"] = [np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1])] | ||
| ad.var["size_factor_pearson"] = [np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1])] | ||
| print( | ||
| f"Mean Pearson Correlation per gene: True: {ad.var.pearson.mean().round(2)} Size Factor: {ad.var.size_factor_pearson.mean().round(2)}" | ||
| ad = predict_gene_expression( | ||
| genes=genes, | ||
| model=model, | ||
| metadata_anndata=metadata, | ||
| device=device, | ||
| batch_size=batch_size, | ||
| num_workers=num_workers, | ||
| max_seq_shift=max_seq_shift, | ||
| genome=genome, | ||
| save_replicates=save_replicates, | ||
| ) | ||
|
|
||
| print("Computing correlation per track") | ||
| for dataset in ad.var.dataset.unique(): | ||
| key = f"{dataset}_pearson" | ||
| ad.obs[key] = [ | ||
| np.corrcoef( | ||
| ad[i, ad.var.dataset == dataset].X, | ||
| ad[i, ad.var.dataset == dataset].layers["preds"], | ||
| )[0, 1] | ||
| for i in range(ad.shape[0]) | ||
| ] | ||
| print(f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {ad.obs[key].mean().round(2)}") | ||
|
|
||
| print("Saved") | ||
| ad.write_h5ad(out_file) | ||
| ad.write_h5ad(output) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.