Kirigami

arXiv preprint: Kirigami: large convolutional kernels improve deep learning-based RNA secondary structure prediction

RNA secondary structure prediction via deep learning.

From Wikipedia:

Kirigami (切り紙) is a variation of origami, the Japanese art of folding paper. In kirigami, the paper is cut as well as being folded, resulting in a three-dimensional design that stands away from the page.

The Kirigami pipeline both folds RNA molecules via a fully convolutional neural network (FCN) and uses Nussinov-style dynamic programming to recursively cut them into subsequences for pre- and post-processing.
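The repository's actual pre- and post-processing routines presumably live in kirigami/nussinov.pyx and kirigami.layers.Greedy; the pure-Python sketch below is not that implementation, but only illustrates the classic Nussinov-style recursion over a pairing-score matrix (the function name, the min_loop parameter, and the use of NumPy are illustrative assumptions):

import numpy as np

def nussinov_pairs(scores: np.ndarray, min_loop: int = 3) -> list:
    """Return a non-crossing set of (i, j) pairs maximizing the summed
    pairing scores, via the classic Nussinov dynamic program."""
    n = scores.shape[0]
    dp = np.zeros((n, n))
    # Fill the DP table by increasing subsequence length.
    for span in range(min_loop + 1, n):
        for i in range(n - span):
            j = i + span
            best = max(dp[i + 1, j], dp[i, j - 1])             # i or j unpaired
            best = max(best, dp[i + 1, j - 1] + scores[i, j])  # pair i with j
            for k in range(i + 1, j):                          # bifurcation
                best = max(best, dp[i, k] + dp[k + 1, j])
            dp[i, j] = best
    # Trace back through the table to recover the pairs themselves.
    pairs, stack = [], [(0, n - 1)]
    while stack:
        i, j = stack.pop()
        if j - i <= min_loop:
            continue
        if dp[i, j] == dp[i + 1, j]:
            stack.append((i + 1, j))
        elif dp[i, j] == dp[i, j - 1]:
            stack.append((i, j - 1))
        elif dp[i, j] == dp[i + 1, j - 1] + scores[i, j]:
            pairs.append((i, j))
            stack.append((i + 1, j - 1))
        else:
            for k in range(i + 1, j):
                if dp[i, j] == dp[i, k] + dp[k + 1, j]:
                    stack.extend([(i, k), (k + 1, j)])
                    break
    return pairs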

Kirigami obtains state-of-the-art (SOTA) performance on the bpRNA TS0 test set; see the arXiv preprint above for the full table of results comparing Kirigami against other SOTA models on TS0.

Overview

For ease of use and reproducibility, all scripts are written idiomatically against the PyTorch Lightning API with as little application-specific code as possible. The five principal classes comprising the package are:

  1. kirigami.layers.ResNet: A standard torch.nn.Module comprising the main network;
  2. kirigami.layers.Greedy: A simple constraint-satisfaction problem (CSP) solver that enforces constraints on the output matrix;
  3. kirigami.data.DataModule: A subclass of LightningDataModule that downloads, embeds, pickles, loads, and collates samples;
  4. kirigami.learner.KirigamiModule: A subclass of LightningModule that wraps kirigami.layers.ResNet and adds a small number of hooks for reproducible logging, checkpointing, training loops, etc.;
  5. kirigami.writer.DbnWriter: A subclass of BasePredictionWriter that writes predicted tensors to files in dot-bracket notation.

The repository is laid out as follows (a minimal sketch of how these classes compose appears after the tree):

├── LICENSE
├── README.md
├── configs
│   ├── predict.yaml
│   └── test.yaml
├── data
│   ├── bpRNA
│   │   ├── TR0.dbn
│   │   ├── TS0.dbn
│   │   └── VL0.dbn
│   └── predict
│       └── input
│           └── bpRNA_CRW_15573.fasta
├── kirigami
│   ├── __init__.py
│   ├── constants.py
│   ├── data.py
│   ├── layers.py
│   ├── learner.py
│   ├── nussinov.pyx
│   ├── utils.py
│   └── writer.py
├── requirements.txt
├── run.py
└── weights
    └── main.ckpt
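
As a rough sketch of how the classes above compose, run.py amounts to wiring the LightningModule and LightningDataModule into a LightningCLI. This is not the verbatim contents of run.py; the import path (which varies by Lightning version) and the omission of the DbnWriter and Greedy wiring are simplifying assumptions:

from lightning.pytorch.cli import LightningCLI

from kirigami.data import DataModule
from kirigami.learner import KirigamiModule

def main():
    # Expose the fit/validate/test/predict subcommands over the
    # LightningModule wrapper and the bpRNA DataModule.
    LightningCLI(KirigamiModule, DataModule)

if __name__ == "__main__":
    main()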

Installation

Note that git-lfs is required to download the weights. Otherwise, no specific setup is necessary. For example, one might run:

git clone https://github.com/marc-harary/kirigami
cd kirigami
git lfs pull
python3 -m venv kirigami-venv
source kirigami-venv/bin/activate
pip3 install -r requirements.txt
python run.py predict && cat data/predict/output/bpRNA_CRW_15573.dbn

Usage

The primary entrypoint for Kirigami is the LightningCLI, which allows for retraining or fine-tuning the model, testing it on the benchmark datasets, predicting novel structures, etc. It is used as follows:

python run.py --help
usage: run.py [-h] [-c CONFIG] [--print_config[=flags]] {fit,validate,test,predict} ...

pytorch-lightning trainer command line tool

options:
  -h, --help            Show this help message and exit.
  -c CONFIG, --config CONFIG
                        Path to a configuration file in json or yaml format.
  --print_config[=flags]
                        Print the configuration after applying all other arguments and exit. The optional flags customizes the output and are one or
                        more keywords separated by comma. The supported flags are: comments, skip_default, skip_null.

subcommands:
  For more details of each subcommand, add it as an argument followed by --help.

  {fit,validate,test,predict}
    fit                 Runs the full optimization routine.
    validate            Perform one evaluation epoch over the validation set.
    test                Perform one evaluation epoch over the test set.
    predict             Run inference on your data.

Default configuration YAML files are located in the configs directory; these in turn point to the weights stored in weights/main.ckpt.
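
For example, a configuration file can be passed explicitly via the --config flag documented above (here assuming configs/test.yaml is the test configuration, as the directory tree suggests):

python run.py test --config configs/test.yaml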

Prediction

Please write all inputs in standard FASTA format to data/predict/input and then call the KirigamiModule.predict method simply by entering:

python run.py predict

Correspondingly named dbn files containing the predicted secondary structure will be written to data/predict/output. An example input file is located at data/predict/input/bpRNA_CRW_15573.fasta.
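
For reference, an input and its output take roughly the following form; the file name, sequence, and structure below are made up for illustration and are not an actual Kirigami prediction:

data/predict/input/toy.fasta:

>toy
GGGCGCAAGCCU

data/predict/output/toy.dbn:

>toy
GGGCGCAAGCCU
((((....))))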

Testing

Running

python run.py test

will evaluate Kirigami on each molecule in TS0, compute accuracy metrics, and output the averages to the terminal.

(Re)training

Although the weights of the production model are located at weights/main.ckpt, Kirigami can be retrained with varying hyperparameters. Simply run

python run.py fit --help

for an exhaustive list of configurations.

To perform an exact, globally seeded replication of the experiment that generated the weights, run

python run.py fit

to use the appropriate configuration file.
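
Individual hyperparameters can also be overridden directly on the command line using the --model.* flags listed under "Model architecture" below; the values here are arbitrary examples, not recommended settings:

python run.py fit --model.n_blocks 16 --model.dropout 0.2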

Data

Data used for training, validation, and testing are taken from the bpRNA database in the form of the standard TR0, VL0, and TS0 datasets used by SPOT-RNA, MXfold2, and UFold. Respectively, these contain 10,814, 1,300, and 1,305 non-redundant structures. The .dbn files located in this repo were generated by scraping the data originally uploaded by the authors of SPOT-RNA.

The data are currently written to dbn files at data/bpRNA/TR0.dbn, data/bpRNA/VL0.dbn, and data/bpRNA/TS0.dbn, but will be embedded as torch.Tensors by the kirigami.data.DataModule.prepare_data method once any of the LightningCLI subcommands is run. Please see the documentation for the LightningDataModule API for more detail.

Model architecture

Kirigami consists of an extremely simple residual network (ResNet) architecture that can be found in kirigami/layers.py, with the primary network API being kirigami.layers.ResNet. The hyperparameters for the model are as follows:

<class 'kirigami.learner.KirigamiModule'>:
  --model.n_blocks N_BLOCKS
  		(required, type: int)
  --model.n_channels N_CHANNELS
  		(required, type: int)
  --model.kernel_sizes [ITEM,...]
  		(required, type: Tuple[int, int])
  --model.dilations [ITEM,...]
  		(required, type: Tuple[int, int])
  --model.activation ACTIVATION
  		(required, type: str)
  --model.dropout DROPOUT

These are:

  1. n_blocks: The total number of residual neural network blocks (i.e., kirigami.layers.ResNetBlock);
  2. n_channels: The number of hidden channels for each block;
  3. kernel_sizes: The dimensions of the kernels for the first and second torch.nn.Conv2d layers in each block;
  4. dilations: The dilations for said convolutional layers;
  5. activation: The class name for the non-linearities in each block;
  6. dropout: The dropout probability for the torch.nn.Dropout layer in each block.

The defaults used for weights/main.ckpt are, respectively, 32, 32, (11, 11), (1, 1), "ReLU", and 0.15.
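
As a minimal sketch, the production network could be built directly in Python roughly as follows, assuming kirigami.layers.ResNet accepts keyword arguments named exactly as the CLI flags above:

from kirigami.layers import ResNet

# Default hyperparameters reported above for weights/main.ckpt.
net = ResNet(
    n_blocks=32,            # number of kirigami.layers.ResNetBlocks
    n_channels=32,          # hidden channels per block
    kernel_sizes=(11, 11),  # kernel sizes of the two conv layers in each block
    dilations=(1, 1),       # dilations of those conv layers
    activation="ReLU",      # class name of the per-block non-linearity
    dropout=0.15,           # per-block dropout probability
)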
