import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from blocks import ConvBlock, FCBlock
import numpy as np

RBP_COUNT = 279      # size of the cell-type (RBP) profile vector
FIX_SEQ_LEN = 4000   # fixed input sequence length


class APAData(Dataset):
    """
    Dataset class for the APA-Net model.

    Args:
        seqs (ndarray): Sequence table; column 0 holds the sequence index and
            column 3 the one-hot encoded sequence matrix.
        df (ndarray): Sample table; column 1 holds the sequence index, column 2
            the cell type name, and column 3 the regression label.
        ct (DataFrame): Cell type profiles, one column per cell type name.
        device (str): Device to use (e.g., 'cuda' or 'cpu').
    """

    def __init__(self, seqs, df, ct, device):
        self.device = device
        self.reg_label = torch.from_numpy(np.array(df[:, 3].tolist(), dtype=np.float32)).to(device)
        self.seq_idx = torch.from_numpy(np.array(df[:, 1].tolist(), dtype=np.int32)).to(device)
        self.oneH_seqs = torch.from_numpy(np.array(list(seqs[:, 3]), dtype=np.int8)).to(device)
        self.oneH_seq_indexes = torch.from_numpy(np.array(seqs[:, 0], dtype=np.int32)).to(device)
        self.celltypes = df[:, 2]
        self.ct_profiles = ct

    def __len__(self):
        return self.reg_label.shape[0]

    def __getitem__(self, idx):
        seq_idx = self.seq_idx[idx]
        # Look up the one-hot sequence by its index; cast to float32 on the
        # dataset's device instead of hard-coding a CUDA tensor type.
        seq = self.oneH_seqs[torch.where(self.oneH_seq_indexes == seq_idx)].squeeze().float()
        reg_label = self.reg_label[idx]
        celltype_name = self.celltypes[idx]
        celltype = torch.from_numpy(self.ct_profiles[celltype_name].values.astype(np.float32)).to(self.device)
        return (seq, celltype, celltype_name, reg_label)
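

# Usage sketch (illustrative, not part of the original module). The input
# layouts are inferred from the indexing in APAData above; load_apa_inputs is
# a hypothetical loading helper, not a function defined in this repository.
#
#     seqs, df, ct = load_apa_inputs(...)
#     dataset = APAData(seqs, df, ct, device="cuda")
#     loader = DataLoader(dataset, batch_size=64, shuffle=True)
#     seq, celltype, celltype_name, reg_label = next(iter(loader))
#     # seq: (batch, 4, FIX_SEQ_LEN) one-hot sequences
#     # celltype: (batch, RBP_COUNT) cell type profile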


class APANET(nn.Module):
    """
    APANET is the deep neural network behind APA-Net.
    It consists of a convolutional block, a multi-head attention block,
    and two fully connected blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = config['device']
        self._build_model()

    def _build_model(self):
        # Convolutional block
        self.conv_block_1 = ConvBlock(
            in_channel=4,
            out_channel=self.config['conv1kc'],
            cnvks=self.config['conv1ks'],
            cnvst=self.config['conv1st'],
            poolks=self.config['pool1ks'],
            poolst=self.config['pool1st'],
            pdropout=self.config['cnvpdrop1'],
            activation_t="ELU",
        )
        # Calculate the sequence length after convolution and pooling
        cnv1_len = self._get_conv1d_out_length(
            FIX_SEQ_LEN,
            self.config['conv1ks'],
            self.config['conv1st'],
            self.config['pool1ks'],
            self.config['pool1st'],
        )

        # Multi-head attention block (embed_dim must match the conv output channels)
        self.attention = nn.MultiheadAttention(
            embed_dim=self.config['conv1kc'],
            num_heads=self.config['Matt_heads'],
            dropout=self.config['Matt_drop'],
        )

        # Fully connected blocks
        fc1_L1 = cnv1_len * self.config['conv1kc']  # flattened conv/attention output size
        self.fc1 = FCBlock(
            layer_dims=[fc1_L1, *self.config['fc1_dims']],
            dropouts=self.config['fc1_dropouts'],
            dropout=True,
        )

        # The second FC block takes the fc1 output concatenated with the cell type profile
        fc2_L1 = self.config['fc1_dims'][-1] + RBP_COUNT
        self.fc2 = FCBlock(
            layer_dims=[fc2_L1, *self.config['fc2_dims']],
            dropouts=self.config['fc2_dropouts'],
            dropout=True,
        )

    def _get_conv1d_out_length(self, l_in, kernel, stride, pool_kernel, pool_stride):
        """Output length of a Conv1d layer (padding=kernel//2, dilation=1) followed by pooling."""
        length_after_conv = (l_in + 2 * (kernel // 2) - (kernel - 1) - 1) // stride + 1
        return (length_after_conv - pool_kernel) // pool_stride + 1
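
    # Worked example (illustrative values only, not defaults from the original
    # config): with FIX_SEQ_LEN = 4000, conv1ks = 11, conv1st = 1, pool1ks = 20
    # and pool1st = 20, the conv length is (4000 + 10 - 10 - 1) // 1 + 1 = 4000
    # and pooling gives (4000 - 20) // 20 + 1 = 200, so the flattened fc1 input
    # is 200 * conv1kc.
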
    def forward(self, seq, celltype):
        # Convolutional forward: (batch, 4, seq_len) -> (batch, channels, conv_len)
        x_conv = self.conv_block_1(seq)
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim) by default
        x = x_conv.permute(2, 0, 1)
        x, _ = self.attention(x, x, x)
        x = x.permute(1, 2, 0)            # back to (batch, channels, conv_len)
        x = x + x_conv                    # residual connection around attention
        x = torch.flatten(x, 1)           # flatten for the FC layers
        x = self.fc1(x)                   # FC block 1
        x = torch.cat((x, celltype), 1)   # concatenate with the cell type profile
        x = self.fc2(x)                   # FC block 2
        return x

    def compile(self):
        """Move the model to its device and set up the optimizer and loss function."""
        self.to(self.device)
        # Note: only 'Adam' (mapped to AdamW) and 'mse' are handled here; other
        # config values leave the optimizer/loss unset.
        if self.config['opt'] == "Adam":
            self.optimizer = optim.AdamW(
                self.parameters(),
                weight_decay=self.config['adam_weight_decay'],
                lr=self.config['lr'],
            )
        if self.config['loss'] == "mse":
            self.loss_fn = nn.MSELoss()

    def save_model(self, filename):
        torch.save(self.state_dict(), filename)

    def load_model(self, filename):
        # map_location keeps checkpoints loadable when the saving and loading devices differ
        self.load_state_dict(torch.load(filename, map_location=self.device))
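

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config values below are illustrative
    # assumptions (not the published APA-Net hyperparameters); only the key
    # names are taken from _build_model() and compile() above.
    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "conv1kc": 128,        # conv output channels / attention embed_dim
        "conv1ks": 11,         # conv kernel size
        "conv1st": 1,          # conv stride
        "pool1ks": 20,         # pooling kernel size
        "pool1st": 20,         # pooling stride
        "cnvpdrop1": 0.2,
        "Matt_heads": 8,       # must divide conv1kc
        "Matt_drop": 0.1,
        "fc1_dims": [1024, 256],
        "fc1_dropouts": [0.3, 0.3],
        "fc2_dims": [64, 1],   # final dim 1 for the regression target
        "fc2_dropouts": [0.3, 0.0],
        "opt": "Adam",
        "adam_weight_decay": 1e-4,
        "lr": 1e-3,
        "loss": "mse",
    }

    model = APANET(config)
    model.compile()

    # Dummy batch: one-hot sequences and cell-type (RBP) profiles.
    seq = torch.randn(2, 4, FIX_SEQ_LEN, device=config["device"])
    celltype = torch.randn(2, RBP_COUNT, device=config["device"])
    out = model(seq, celltype)
    print(out.shape)  # with the config above this should be torch.Size([2, 1])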