Skip to content

Commit f6efd04

Browse files
committed
publishing APA-Net
1 parent 01ddcc3 commit f6efd04

File tree

3 files changed

+182
-90
lines changed

3 files changed

+182
-90
lines changed

apamodel/blocks.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5+
56
class ConvBlock(nn.Module):
67
"""
78
Convolutional Block for neural networks.
@@ -10,6 +11,7 @@ class ConvBlock(nn.Module):
1011
out_channel (int): Number of output channels.
1112
...
1213
"""
14+
1315
def __init__(
1416
self,
1517
in_channel,
@@ -47,6 +49,7 @@ class FCBlock(nn.Module):
4749
dropouts (list): Dropout values for layers.
4850
dropout (bool): Whether to apply dropout.
4951
"""
52+
5053
def __init__(self, layer_dims, dropouts, dropout=False):
5154
super(FCBlock, self).__init__()
5255
layers = []

apamodel/model.py

Lines changed: 56 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
RBP_COUNT = 279
99
FIX_SEQ_LEN = 4000
10+
11+
1012
class APAData(Dataset):
1113
"""
1214
APAData is a dataset class for APA-Net model.
@@ -16,12 +18,21 @@ class APAData(Dataset):
1618
ct (DataFrame): Cell type profiles.
1719
device (str): Device to use (e.g., 'cuda' or 'cpu').
1820
"""
21+
1922
def __init__(self, seqs, df, ct, device):
2023
self.device = device
21-
self.reg_label = torch.from_numpy(np.array(df[:, 3].tolist(), dtype=np.float32)).to(device)
22-
self.seq_idx = torch.from_numpy(np.array(df[:, 1].tolist(), dtype=np.int32)).to(device)
23-
self.oneH_seqs = torch.from_numpy(np.array(list(seqs[:, 3]), dtype=np.int8)).to(device)
24-
self.oneH_seq_indexes = torch.from_numpy(np.array(seqs[:, 0], dtype=np.int32)).to(device)
24+
self.reg_label = torch.from_numpy(
25+
np.array(df[:, 3].tolist(), dtype=np.float32)
26+
).to(device)
27+
self.seq_idx = torch.from_numpy(np.array(df[:, 1].tolist(), dtype=np.int32)).to(
28+
device
29+
)
30+
self.oneH_seqs = torch.from_numpy(np.array(list(seqs[:, 3]), dtype=np.int8)).to(
31+
device
32+
)
33+
self.oneH_seq_indexes = torch.from_numpy(
34+
np.array(seqs[:, 0], dtype=np.int32)
35+
).to(device)
2536
self.celltypes = df[:, 2]
2637
self.ct_profiles = ct
2738

@@ -30,10 +41,16 @@ def __len__(self):
3041

3142
def __getitem__(self, idx):
3243
seq_idx = self.seq_idx[idx]
33-
seq = self.oneH_seqs[torch.where(self.oneH_seq_indexes == seq_idx)].squeeze().type(torch.cuda.FloatTensor)
44+
seq = (
45+
self.oneH_seqs[torch.where(self.oneH_seq_indexes == seq_idx)]
46+
.squeeze()
47+
.type(torch.cuda.FloatTensor)
48+
)
3449
reg_label = self.reg_label[idx]
3550
celltype_name = self.celltypes[idx]
36-
celltype = torch.from_numpy(self.ct_profiles[celltype_name].values.astype(np.float32)).to(self.device)
51+
celltype = torch.from_numpy(
52+
self.ct_profiles[celltype_name].values.astype(np.float32)
53+
).to(self.device)
3754
return (seq, celltype, celltype_name, reg_label)
3855

3956

@@ -42,52 +59,61 @@ class APANET(nn.Module):
4259
APANET is a deep neural network for APA-Net.
4360
Includes Convolutional, Attention, and Fully Connected blocks.
4461
"""
62+
4563
def __init__(self, config):
4664
super(APANET, self).__init__()
4765
self.config = config
48-
self.device = config['device']
66+
self.device = config["device"]
4967
self._build_model()
5068

5169
def _build_model(self):
5270
# Convolutional Block
5371
self.conv_block_1 = ConvBlock(
5472
in_channel=4,
55-
out_channel=self.config['conv1kc'],
56-
cnvks=self.config['conv1ks'],
57-
cnvst=self.config['conv1st'],
58-
poolks=self.config['pool1ks'],
59-
poolst=self.config['pool1st'],
60-
pdropout=self.config['cnvpdrop1'],
73+
out_channel=self.config["conv1kc"],
74+
cnvks=self.config["conv1ks"],
75+
cnvst=self.config["conv1st"],
76+
poolks=self.config["pool1ks"],
77+
poolst=self.config["pool1st"],
78+
pdropout=self.config["cnvpdrop1"],
6179
activation_t="ELU",
6280
)
6381
# Calculate output length after Convolution
64-
cnv1_len = self._get_conv1d_out_length(FIX_SEQ_LEN, self.config['conv1ks'], self.config['conv1st'], self.config['pool1ks'], self.config['pool1st'])
82+
cnv1_len = self._get_conv1d_out_length(
83+
FIX_SEQ_LEN,
84+
self.config["conv1ks"],
85+
self.config["conv1st"],
86+
self.config["pool1ks"],
87+
self.config["pool1st"],
88+
)
6589

6690
# Attention Block
6791
self.attention = nn.MultiheadAttention(
68-
embed_dim=self.config['conv1kc'],
69-
num_heads=self.config['Matt_heads'],
70-
dropout=self.config['Matt_drop']
92+
embed_dim=self.config["conv1kc"],
93+
num_heads=self.config["Matt_heads"],
94+
dropout=self.config["Matt_drop"],
7195
)
7296

7397
# Fully Connected Blocks
74-
fc1_L1 = cnv1_len * self.config['conv1kc']
98+
fc1_L1 = cnv1_len * self.config["conv1kc"]
7599
self.fc1 = FCBlock(
76-
layer_dims=[fc1_L1, *self.config['fc1_dims']],
77-
dropouts=self.config['fc1_dropouts'],
100+
layer_dims=[fc1_L1, *self.config["fc1_dims"]],
101+
dropouts=self.config["fc1_dropouts"],
78102
dropout=True,
79103
)
80104

81-
fc2_L1 = self.config['fc1_dims'][-1] + RBP_COUNT
105+
fc2_L1 = self.config["fc1_dims"][-1] + RBP_COUNT
82106
self.fc2 = FCBlock(
83-
layer_dims=[fc2_L1, *self.config['fc2_dims']],
84-
dropouts=self.config['fc2_dropouts'],
107+
layer_dims=[fc2_L1, *self.config["fc2_dims"]],
108+
dropouts=self.config["fc2_dropouts"],
85109
dropout=True,
86110
)
87111

88112
def _get_conv1d_out_length(self, l_in, kernel, stride, pool_kernel, pool_stride):
89-
""" Utility method to calculate output length of Conv1D layer. """
90-
length_after_conv = (l_in + 2 * (kernel // 2) - 1 * (kernel - 1) - 1) // stride + 1
113+
"""Utility method to calculate output length of Conv1D layer."""
114+
length_after_conv = (
115+
l_in + 2 * (kernel // 2) - 1 * (kernel - 1) - 1
116+
) // stride + 1
91117
return (length_after_conv - pool_kernel) // pool_stride + 1
92118

93119
def forward(self, seq, celltype):
@@ -104,15 +130,15 @@ def forward(self, seq, celltype):
104130
return x
105131

106132
def compile(self):
107-
""" Compile the model with optimizer and loss function. """
133+
"""Compile the model with optimizer and loss function."""
108134
self.to(self.device)
109-
if self.config['opt'] == "Adam":
135+
if self.config["opt"] == "Adam":
110136
self.optimizer = optim.AdamW(
111137
self.parameters(),
112-
weight_decay=self.config['adam_weight_decay'],
113-
lr=self.config['lr']
138+
weight_decay=self.config["adam_weight_decay"],
139+
lr=self.config["lr"],
114140
)
115-
if self.config['loss'] == "mse":
141+
if self.config["loss"] == "mse":
116142
self.loss_fn = nn.MSELoss()
117143

118144
def save_model(self, filename):

0 commit comments

Comments
 (0)