Kirigami: large convolutional kernels improve deep learning-based RNA secondary structure prediction
RNA secondary structure prediction via deep learning.
From Wikipedia:
Kirigami (切り紙) is a variation of origami, the Japanese art of folding paper. In kirigami, the paper is cut as well as being folded, resulting in a three-dimensional design that stands away from the page.
The Kirigami pipeline both folds RNA molecules via a fully convolutional neural network (FCN) and uses Nussinov-style dynamic programming to recursively cut them into subsequences for pre- and post-processing.
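For intuition, the Nussinov recursion maximizes the number of nested base pairs in a subsequence by considering, at each position, whether it is unpaired or paired with some downstream partner that splits the problem in two. Below is a minimal Python sketch of the textbook recursion; the repository's optimized implementation lives in kirigami/nussinov.pyx and may differ in its scoring and constraints:

```python
# Textbook Nussinov dynamic program (illustrative only; see kirigami/nussinov.pyx
# for the repository's actual implementation).
def nussinov(seq: str, min_loop: int = 3) -> int:
    """Return the maximum number of nested base pairs in `seq`."""
    pairs = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}
    n = len(seq)
    if n == 0:
        return 0
    dp = [[0] * n for _ in range(n)]  # dp[i][j] = best pair count on seq[i..j]
    for span in range(min_loop + 1, n):
        for i in range(n - span):
            j = i + span
            best = dp[i + 1][j]  # case 1: position i is unpaired
            for k in range(i + min_loop + 1, j + 1):
                if (seq[i], seq[k]) in pairs:  # case 2: i pairs with k
                    right = dp[k + 1][j] if k + 1 <= j else 0
                    best = max(best, 1 + dp[i + 1][k - 1] + right)
            dp[i][j] = best
    return dp[0][n - 1]

print(nussinov("GGGAAACCC"))  # 3, i.e., the structure (((...)))
```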
Kirigami obtains state-of-the-art (SOTA) performance on TS0, the standard bpRNA test set.
For ease of use and reproducibility, all scripts are written idiomatically against the PyTorch Lightning specification with as little application-specific code as possible. The five principal classes comprising the module are:
- `kirigami.layers.ResNet`: A standard `torch.nn.Module` comprising the main network;
- `kirigami.layers.Greedy`: A simple constraint-satisfaction problem (CSP) solver that enforces constraints on the output matrix;
- `kirigami.data.DataModule`: A subclass of `LightningDataModule` that downloads, embeds, pickles, loads, and collates samples;
- `kirigami.learner.KirigamiModule`: A subclass of `LightningModule` that wraps `kirigami.layers.ResNet` and includes a small number of hooks for reproducible logging, checkpointing, training loops, etc.;
- `kirigami.writer.DbnWriter`: A subclass of `BasePredictionWriter` that writes predicted tensors to files in dot-bracket notation.
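To illustrate how these pieces compose, here is a hypothetical skeleton of the Lightning wiring; the class name, loss, and method bodies are illustrative, not the repository's actual code:

```python
# Hypothetical skeleton of the Lightning wiring; the real KirigamiModule in
# kirigami/learner.py differs in its hooks, metrics, and post-processing.
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class ToyLearner(pl.LightningModule):
    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net  # e.g., an instance of kirigami.layers.ResNet

    def training_step(self, batch, batch_idx):
        feats, target = batch
        pred = self.net(feats)  # predicted base-pairing matrix
        loss = F.binary_cross_entropy(pred, target)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```

The repository is laid out as follows: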
├── LICENSE
├── README.md
├── configs
│ ├── predict.yaml
│ └── test.yaml
├── data
│ ├── bpRNA
│ │ ├── TR0.dbn
│ │ ├── TS0.dbn
│ │ └── VL0.dbn
│ └── predict
│ └── input
│ └── bpRNA_CRW_15573.fasta
├── kirigami
│ ├── __init__.py
│ ├── constants.py
│ ├── data.py
│ ├── layers.py
│ ├── learner.py
│ ├── nussinov.pyx
│ ├── utils.py
│ └── writer.py
├── requirements.txt
├── run.py
└── weights
    └── main.ckpt

Note that git-lfs is required to download the weights. Otherwise, no specific setup is necessary. For example, one might run:
git clone https://github.com/marc-harary/kirigami
cd kirigami
git lfs pull
python3 -m venv kirigami-venv
source kirigami-venv/bin/activate
pip3 install -r requirements.txt
python run.py predict && cat data/predict/output/bpRNA_CRW_15573.dbn

The primary entrypoint for Kirigami is the LightningCLI, which allows for retraining or fine-tuning the model, testing it on the benchmark datasets, predicting novel structures, etc. It is used as follows:
python run.py --help
usage: run.py [-h] [-c CONFIG] [--print_config[=flags]] {fit,validate,test,predict} ...
pytorch-lightning trainer command line tool
options:
-h, --help Show this help message and exit.
-c CONFIG, --config CONFIG
Path to a configuration file in json or yaml format.
--print_config[=flags]
Print the configuration after applying all other arguments and exit. The optional flags customizes the output and are one or
more keywords separated by comma. The supported flags are: comments, skip_default, skip_null.
subcommands:
For more details of each subcommand, add it as an argument followed by --help.
{fit,validate,test,predict}
fit Runs the full optimization routine.
validate Perform one evaluation epoch over the validation set.
test Perform one evaluation epoch over the test set.
predict Run inference on your data.

Default configuration YAML files are located in the configs directory; these in turn point to the weights stored in weights/main.ckpt.
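The entrypoint itself presumably amounts to little more than instantiating the CLI. A minimal sketch, assuming the standard LightningCLI pattern (the actual run.py may differ):

```python
# Plausible sketch of run.py; LightningCLI auto-generates the
# fit/validate/test/predict subcommands shown above.
from pytorch_lightning.cli import LightningCLI

from kirigami.data import DataModule
from kirigami.learner import KirigamiModule

if __name__ == "__main__":
    LightningCLI(KirigamiModule, DataModule)
```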
Please write all inputs in standard FASTA format to data/predict/input and then call the KirigamiModule.predict method simply by entering:
python run.py predict

Correspondingly named .dbn files containing the predicted secondary structure will be written to data/predict/output. An example file is located in data/predict/input/bpRNA_CRW_15573.fasta.
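For reference, a FASTA input and its corresponding .dbn output look roughly as follows; the sequence and structure shown are illustrative, not an actual Kirigami prediction. Input:

```
>example
GGGAAACCC
```

Output:

```
>example
GGGAAACCC
(((...)))
```

In dot-bracket notation, matching parentheses denote paired bases and dots denote unpaired bases.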
Running
python run.py test

will evaluate Kirigami on each molecule in TS0, compute accuracy metrics, and output the averages to the terminal.
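The standard accuracy metric for this task is F1 over predicted base pairs, which can be sketched per molecule as below; this helper is illustrative and not necessarily the exact implementation in kirigami/utils.py:

```python
# Illustrative per-molecule F1 over base pairs; the metrics Kirigami actually
# reports are computed internally (see kirigami/utils.py).
def pair_f1(pred: set[tuple[int, int]], true: set[tuple[int, int]]) -> float:
    tp = len(pred & true)  # correctly predicted pairs
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true) if true else 0.0
    denom = precision + recall
    return 2 * precision * recall / denom if denom else 0.0
```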
Although the weights of the production model are located at weights/main.ckpt, Kirigami can be retrained with varying hyperparameters. Simply run
python run.py fit --help

for an exhaustive list of configurations.
To perform an exact, globally seeded replication of the experiment that generated the weights, run
python run.py fit

to use the appropriate configuration file.
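Exact replication hinges on global seeding, which Lightning performs via seed_everything; the call below is illustrative, and the placeholder seed is not necessarily the one behind the published weights (see the shipped configs for the actual value):

```python
# Illustrative only: LightningCLI drives this via the `seed_everything`
# config key; 42 is a placeholder, not necessarily the production seed.
import pytorch_lightning as pl

pl.seed_everything(42, workers=True)
```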
Data used for training, validation, and testing are taken from the bpRNA database in the form of the standard TR0, VL0, and TS0 datasets used by SPOT-RNA, MXfold2, and UFold. Respectively, these contain 10,814, 1,300, and 1,305 non-redundant structures. The .dbn files located in this repo were generated by scraping the data originally uploaded by the authors of SPOT-RNA.
The data are currently stored as .dbn files at data/bpRNA/TR0.dbn, data/bpRNA/VL0.dbn, and data/bpRNA/TS0.dbn, but will be embedded as `torch.Tensor`s by the `kirigami.data.DataModule.prepare_data` method once any of the LightningCLI subcommands is run. Please see the documentation for the LightningDataModule API for more detail.
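Concretely, the embedding presumably amounts to one-hot encoding the sequence and tiling it into a pairwise tensor suitable for a 2-D convolutional network. A minimal sketch, assuming a plain ACGU alphabet (the actual prepare_data may handle degenerate bases and use a different feature layout):

```python
# Illustrative embedding; kirigami.data.DataModule.prepare_data may differ
# (e.g., degenerate bases, padding, a different channel layout).
import torch
import torch.nn.functional as F

BASES = "ACGU"  # assumed alphabet

def one_hot(seq: str) -> torch.Tensor:
    idx = torch.tensor([BASES.index(c) for c in seq.upper()])
    return F.one_hot(idx, num_classes=len(BASES)).float()  # (L, 4)

def pairwise(seq: str) -> torch.Tensor:
    x = one_hot(seq)                                         # (L, 4)
    L = x.size(0)
    rows = x.unsqueeze(1).expand(L, L, -1)                   # (L, L, 4)
    cols = x.unsqueeze(0).expand(L, L, -1)                   # (L, L, 4)
    return torch.cat([rows, cols], dim=-1).permute(2, 0, 1)  # (8, L, L)
```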
Kirigami consists of an extremely simple residual neural network (ResNet) architecture that can be found in kirigami/layers.py, with the primary network API being `kirigami.layers.ResNet`. The hyperparameters for the model are as follows:
<class 'kirigami.learner.KirigamiModule'>:
--model.n_blocks N_BLOCKS
(required, type: int)
--model.n_channels N_CHANNELS
(required, type: int)
--model.kernel_sizes [ITEM,...]
(required, type: Tuple[int, int])
--model.dilations [ITEM,...]
(required, type: Tuple[int, int])
--model.activation ACTIVATION
(required, type: str)
--model.dropout DROPOUT

These are:
- `n_blocks`: The total number of residual neural network blocks (i.e., `kirigami.layers.ResNetBlock`);
- `n_channels`: The number of hidden channels for each block;
- `kernel_sizes`: The dimensions of the kernels for the first and second `torch.nn.Conv2d` layers in each block;
- `dilations`: The dilations for said convolutional layers;
- `activation`: The class name for the non-linearities in each block;
- `dropout`: The dropout probability for the `torch.nn.Dropout` layer in each block.
The default parameters used in weights/main.ckpt are, respectively, 32, 32, (11, 11), (1, 1), "ReLU", and 0.15.
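Putting the defaults together, a residual block consistent with these hyperparameters might look like the sketch below; the actual kirigami.layers.ResNetBlock may differ in normalization and layer ordering:

```python
# Hypothetical block mirroring the default hyperparameters; the real
# kirigami.layers.ResNetBlock may order layers differently or add normalization.
import torch
from torch import nn

class Block(nn.Module):
    def __init__(self, n_channels=32, kernel_sizes=(11, 11), dilations=(1, 1),
                 activation="ReLU", dropout=0.15):
        super().__init__()
        act = getattr(nn, activation)  # e.g., nn.ReLU

        def conv(k, d):
            # "same" padding keeps the feature map square so the residual add works
            return nn.Conv2d(n_channels, n_channels, k,
                             padding=d * (k - 1) // 2, dilation=d)

        self.body = nn.Sequential(
            conv(kernel_sizes[0], dilations[0]), act(), nn.Dropout(dropout),
            conv(kernel_sizes[1], dilations[1]), act(),
        )

    def forward(self, x):
        return x + self.body(x)  # residual connection
```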