Skip to content

Commit 01ddcc3

Browse files
committed
Publishing APA-Net
1 parent a51c245 commit 01ddcc3

File tree

6 files changed

+454
-0
lines changed

6 files changed

+454
-0
lines changed

README.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# APA-Net
2+
3+
APA-Net is a deep learning model designed for [brief description of the model's purpose or use case]. This guide covers the steps necessary to set up and run APA-Net.
4+
5+
## Installation
6+
7+
Before running APA-Net, ensure you have Python installed on your system. Clone this repository to your local machine:
8+
9+
```bash
10+
git clone https://github.com/yourusername/APA-Net.git
11+
cd APA-Net
12+
13+
conda env create -f environments/environment.yml
14+
conda activate apa-net-env
15+
16+
pip install .
17+
18+
```
19+
20+
# Usage
21+
22+
To train the APA-Net model, use the train_script.py script with the necessary command-line arguments:
23+
24+
```bash
25+
python train_script.py \
26+
--train_data "/path/to/train_data.npy" \
27+
--train_seq "/path/to/train_seq.npy" \
28+
--valid_data "/path/to/valid_data.npy" \
29+
--valid_seq "/path/to/valid_seq.npy" \
30+
--profiles "/path/to/celltype_profiles.tsv" \
31+
--modelfile "/path/to/model_output.pt" \
32+
--batch_size 64 \
33+
--epochs 200 \
34+
--project_name "APA-Net_Training" \
35+
--device "cuda:1" \
36+
--use_wandb "True"
37+
```
38+
39+
# Arguments
40+
- `--train_data`: Path to the training data file.
41+
- `--train_seq`: Path to the training sequence data file.
42+
- `--valid_data`: Path to the validation data file.
43+
- `--valid_seq`: Path to the validation sequence data file.
44+
- `--profiles`: Path to the cell type profiles file.
45+
- `--modelfile`: Path where the trained model will be saved.
46+
- `--batch_size`: Batch size for training (default: 64).
47+
- `--epochs`: Number of training epochs (default: 200).
48+
- `--project_name`: Name of the project for wandb logging.
49+
- `--device`: Device to run the training on (e.g., 'cuda:1').
50+
- `--use_wandb`: Flag to enable or disable wandb logging ('True' or 'False').
51+

apamodel/__init__.py

Whitespace-only changes.

apamodel/blocks.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class ConvBlock(nn.Module):
6+
"""
7+
Convolutional Block for neural networks.
8+
Args:
9+
in_channel (int): Number of input channels.
10+
out_channel (int): Number of output channels.
11+
...
12+
"""
13+
def __init__(
14+
self,
15+
in_channel,
16+
out_channel,
17+
cnvks=1,
18+
cnvst=1,
19+
poolks=1,
20+
poolst=1,
21+
pdropout=0,
22+
activation_t="none",
23+
):
24+
super(ConvBlock, self).__init__()
25+
activations = {
26+
"ELU": nn.ELU(),
27+
"LeakyReLU": nn.LeakyReLU(),
28+
"none": nn.Identity(),
29+
}
30+
self.op = nn.Sequential(
31+
nn.Conv1d(in_channel, out_channel, cnvks, cnvst, padding=cnvks // 2),
32+
nn.BatchNorm1d(out_channel),
33+
activations[activation_t],
34+
nn.MaxPool1d(kernel_size=poolks, stride=poolst),
35+
nn.Dropout(p=pdropout),
36+
)
37+
38+
def forward(self, x):
39+
return self.op(x)
40+
41+
42+
class FCBlock(nn.Module):
43+
"""
44+
Fully Connected Block for neural networks.
45+
Args:
46+
layer_dims (list): Dimensions of layers.
47+
dropouts (list): Dropout values for layers.
48+
dropout (bool): Whether to apply dropout.
49+
"""
50+
def __init__(self, layer_dims, dropouts, dropout=False):
51+
super(FCBlock, self).__init__()
52+
layers = []
53+
for i in range(len(layer_dims) - 1):
54+
layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))
55+
if i < len(layer_dims) - 2:
56+
layers.append(nn.BatchNorm1d(num_features=layer_dims[i + 1]))
57+
layers.append(nn.ReLU())
58+
if dropout:
59+
layers.append(nn.Dropout(p=dropouts[i]))
60+
self.op = nn.Sequential(*layers)
61+
62+
def forward(self, x):
63+
return self.op(x)

apamodel/model.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
from torch.utils.data import Dataset, DataLoader
5+
from blocks import ConvBlock, FCBlock
6+
import numpy as np
7+
8+
RBP_COUNT = 279
9+
FIX_SEQ_LEN = 4000
10+
class APAData(Dataset):
11+
"""
12+
APAData is a dataset class for APA-Net model.
13+
Args:
14+
seqs (Tensor): Sequences tensor.
15+
df (DataFrame): Dataframe containing sample information.
16+
ct (DataFrame): Cell type profiles.
17+
device (str): Device to use (e.g., 'cuda' or 'cpu').
18+
"""
19+
def __init__(self, seqs, df, ct, device):
20+
self.device = device
21+
self.reg_label = torch.from_numpy(np.array(df[:, 3].tolist(), dtype=np.float32)).to(device)
22+
self.seq_idx = torch.from_numpy(np.array(df[:, 1].tolist(), dtype=np.int32)).to(device)
23+
self.oneH_seqs = torch.from_numpy(np.array(list(seqs[:, 3]), dtype=np.int8)).to(device)
24+
self.oneH_seq_indexes = torch.from_numpy(np.array(seqs[:, 0], dtype=np.int32)).to(device)
25+
self.celltypes = df[:, 2]
26+
self.ct_profiles = ct
27+
28+
def __len__(self):
29+
return self.reg_label.shape[0]
30+
31+
def __getitem__(self, idx):
32+
seq_idx = self.seq_idx[idx]
33+
seq = self.oneH_seqs[torch.where(self.oneH_seq_indexes == seq_idx)].squeeze().type(torch.cuda.FloatTensor)
34+
reg_label = self.reg_label[idx]
35+
celltype_name = self.celltypes[idx]
36+
celltype = torch.from_numpy(self.ct_profiles[celltype_name].values.astype(np.float32)).to(self.device)
37+
return (seq, celltype, celltype_name, reg_label)
38+
39+
40+
class APANET(nn.Module):
41+
"""
42+
APANET is a deep neural network for APA-Net.
43+
Includes Convolutional, Attention, and Fully Connected blocks.
44+
"""
45+
def __init__(self, config):
46+
super(APANET, self).__init__()
47+
self.config = config
48+
self.device = config['device']
49+
self._build_model()
50+
51+
def _build_model(self):
52+
# Convolutional Block
53+
self.conv_block_1 = ConvBlock(
54+
in_channel=4,
55+
out_channel=self.config['conv1kc'],
56+
cnvks=self.config['conv1ks'],
57+
cnvst=self.config['conv1st'],
58+
poolks=self.config['pool1ks'],
59+
poolst=self.config['pool1st'],
60+
pdropout=self.config['cnvpdrop1'],
61+
activation_t="ELU",
62+
)
63+
# Calculate output length after Convolution
64+
cnv1_len = self._get_conv1d_out_length(FIX_SEQ_LEN, self.config['conv1ks'], self.config['conv1st'], self.config['pool1ks'], self.config['pool1st'])
65+
66+
# Attention Block
67+
self.attention = nn.MultiheadAttention(
68+
embed_dim=self.config['conv1kc'],
69+
num_heads=self.config['Matt_heads'],
70+
dropout=self.config['Matt_drop']
71+
)
72+
73+
# Fully Connected Blocks
74+
fc1_L1 = cnv1_len * self.config['conv1kc']
75+
self.fc1 = FCBlock(
76+
layer_dims=[fc1_L1, *self.config['fc1_dims']],
77+
dropouts=self.config['fc1_dropouts'],
78+
dropout=True,
79+
)
80+
81+
fc2_L1 = self.config['fc1_dims'][-1] + RBP_COUNT
82+
self.fc2 = FCBlock(
83+
layer_dims=[fc2_L1, *self.config['fc2_dims']],
84+
dropouts=self.config['fc2_dropouts'],
85+
dropout=True,
86+
)
87+
88+
def _get_conv1d_out_length(self, l_in, kernel, stride, pool_kernel, pool_stride):
89+
""" Utility method to calculate output length of Conv1D layer. """
90+
length_after_conv = (l_in + 2 * (kernel // 2) - 1 * (kernel - 1) - 1) // stride + 1
91+
return (length_after_conv - pool_kernel) // pool_stride + 1
92+
93+
def forward(self, seq, celltype):
94+
# Convolutional forward
95+
x_conv = self.conv_block_1(seq)
96+
x = x_conv.permute(2, 0, 1) # reshape for attention block
97+
x, _ = self.attention(x, x, x)
98+
x = x.permute(1, 2, 0) # reshape back
99+
x = x + x_conv # add residual connection
100+
x = torch.flatten(x, 1) # flatten for FC layers
101+
x = self.fc1(x) # FC block 1
102+
x = torch.cat((x, celltype), 1) # concat with celltype profile
103+
x = self.fc2(x) # FC block 2
104+
return x
105+
106+
def compile(self):
107+
""" Compile the model with optimizer and loss function. """
108+
self.to(self.device)
109+
if self.config['opt'] == "Adam":
110+
self.optimizer = optim.AdamW(
111+
self.parameters(),
112+
weight_decay=self.config['adam_weight_decay'],
113+
lr=self.config['lr']
114+
)
115+
if self.config['loss'] == "mse":
116+
self.loss_fn = nn.MSELoss()
117+
118+
def save_model(self, filename):
119+
torch.save(self.state_dict(), filename)
120+
121+
def load_model(self, filename):
122+
self.load_state_dict(torch.load(filename))

0 commit comments

Comments
 (0)