77
88RBP_COUNT = 279
99FIX_SEQ_LEN = 4000
10+
11+
1012class APAData (Dataset ):
1113 """
1214 APAData is a dataset class for APA-Net model.
@@ -16,12 +18,21 @@ class APAData(Dataset):
1618 ct (DataFrame): Cell type profiles.
1719 device (str): Device to use (e.g., 'cuda' or 'cpu').
1820 """
21+
1922 def __init__ (self , seqs , df , ct , device ):
2023 self .device = device
21- self .reg_label = torch .from_numpy (np .array (df [:, 3 ].tolist (), dtype = np .float32 )).to (device )
22- self .seq_idx = torch .from_numpy (np .array (df [:, 1 ].tolist (), dtype = np .int32 )).to (device )
23- self .oneH_seqs = torch .from_numpy (np .array (list (seqs [:, 3 ]), dtype = np .int8 )).to (device )
24- self .oneH_seq_indexes = torch .from_numpy (np .array (seqs [:, 0 ], dtype = np .int32 )).to (device )
24+ self .reg_label = torch .from_numpy (
25+ np .array (df [:, 3 ].tolist (), dtype = np .float32 )
26+ ).to (device )
27+ self .seq_idx = torch .from_numpy (np .array (df [:, 1 ].tolist (), dtype = np .int32 )).to (
28+ device
29+ )
30+ self .oneH_seqs = torch .from_numpy (np .array (list (seqs [:, 3 ]), dtype = np .int8 )).to (
31+ device
32+ )
33+ self .oneH_seq_indexes = torch .from_numpy (
34+ np .array (seqs [:, 0 ], dtype = np .int32 )
35+ ).to (device )
2536 self .celltypes = df [:, 2 ]
2637 self .ct_profiles = ct
2738
@@ -30,10 +41,16 @@ def __len__(self):
3041
3142 def __getitem__ (self , idx ):
3243 seq_idx = self .seq_idx [idx ]
33- seq = self .oneH_seqs [torch .where (self .oneH_seq_indexes == seq_idx )].squeeze ().type (torch .cuda .FloatTensor )
44+ seq = (
45+ self .oneH_seqs [torch .where (self .oneH_seq_indexes == seq_idx )]
46+ .squeeze ()
47+ .type (torch .cuda .FloatTensor )
48+ )
3449 reg_label = self .reg_label [idx ]
3550 celltype_name = self .celltypes [idx ]
36- celltype = torch .from_numpy (self .ct_profiles [celltype_name ].values .astype (np .float32 )).to (self .device )
51+ celltype = torch .from_numpy (
52+ self .ct_profiles [celltype_name ].values .astype (np .float32 )
53+ ).to (self .device )
3754 return (seq , celltype , celltype_name , reg_label )
3855
3956
@@ -42,52 +59,61 @@ class APANET(nn.Module):
4259 APANET is a deep neural network for APA-Net.
4360 Includes Convolutional, Attention, and Fully Connected blocks.
4461 """
62+
4563 def __init__ (self , config ):
4664 super (APANET , self ).__init__ ()
4765 self .config = config
48- self .device = config [' device' ]
66+ self .device = config [" device" ]
4967 self ._build_model ()
5068
5169 def _build_model (self ):
5270 # Convolutional Block
5371 self .conv_block_1 = ConvBlock (
5472 in_channel = 4 ,
55- out_channel = self .config [' conv1kc' ],
56- cnvks = self .config [' conv1ks' ],
57- cnvst = self .config [' conv1st' ],
58- poolks = self .config [' pool1ks' ],
59- poolst = self .config [' pool1st' ],
60- pdropout = self .config [' cnvpdrop1' ],
73+ out_channel = self .config [" conv1kc" ],
74+ cnvks = self .config [" conv1ks" ],
75+ cnvst = self .config [" conv1st" ],
76+ poolks = self .config [" pool1ks" ],
77+ poolst = self .config [" pool1st" ],
78+ pdropout = self .config [" cnvpdrop1" ],
6179 activation_t = "ELU" ,
6280 )
6381 # Calculate output length after Convolution
64- cnv1_len = self ._get_conv1d_out_length (FIX_SEQ_LEN , self .config ['conv1ks' ], self .config ['conv1st' ], self .config ['pool1ks' ], self .config ['pool1st' ])
82+ cnv1_len = self ._get_conv1d_out_length (
83+ FIX_SEQ_LEN ,
84+ self .config ["conv1ks" ],
85+ self .config ["conv1st" ],
86+ self .config ["pool1ks" ],
87+ self .config ["pool1st" ],
88+ )
6589
6690 # Attention Block
6791 self .attention = nn .MultiheadAttention (
68- embed_dim = self .config [' conv1kc' ],
69- num_heads = self .config [' Matt_heads' ],
70- dropout = self .config [' Matt_drop' ]
92+ embed_dim = self .config [" conv1kc" ],
93+ num_heads = self .config [" Matt_heads" ],
94+ dropout = self .config [" Matt_drop" ],
7195 )
7296
7397 # Fully Connected Blocks
74- fc1_L1 = cnv1_len * self .config [' conv1kc' ]
98+ fc1_L1 = cnv1_len * self .config [" conv1kc" ]
7599 self .fc1 = FCBlock (
76- layer_dims = [fc1_L1 , * self .config [' fc1_dims' ]],
77- dropouts = self .config [' fc1_dropouts' ],
100+ layer_dims = [fc1_L1 , * self .config [" fc1_dims" ]],
101+ dropouts = self .config [" fc1_dropouts" ],
78102 dropout = True ,
79103 )
80104
81- fc2_L1 = self .config [' fc1_dims' ][- 1 ] + RBP_COUNT
105+ fc2_L1 = self .config [" fc1_dims" ][- 1 ] + RBP_COUNT
82106 self .fc2 = FCBlock (
83- layer_dims = [fc2_L1 , * self .config [' fc2_dims' ]],
84- dropouts = self .config [' fc2_dropouts' ],
107+ layer_dims = [fc2_L1 , * self .config [" fc2_dims" ]],
108+ dropouts = self .config [" fc2_dropouts" ],
85109 dropout = True ,
86110 )
87111
88112 def _get_conv1d_out_length (self , l_in , kernel , stride , pool_kernel , pool_stride ):
89- """ Utility method to calculate output length of Conv1D layer. """
90- length_after_conv = (l_in + 2 * (kernel // 2 ) - 1 * (kernel - 1 ) - 1 ) // stride + 1
113+ """Utility method to calculate output length of Conv1D layer."""
114+ length_after_conv = (
115+ l_in + 2 * (kernel // 2 ) - 1 * (kernel - 1 ) - 1
116+ ) // stride + 1
91117 return (length_after_conv - pool_kernel ) // pool_stride + 1
92118
93119 def forward (self , seq , celltype ):
@@ -104,15 +130,15 @@ def forward(self, seq, celltype):
104130 return x
105131
106132 def compile (self ):
107- """ Compile the model with optimizer and loss function. """
133+ """Compile the model with optimizer and loss function."""
108134 self .to (self .device )
109- if self .config [' opt' ] == "Adam" :
135+ if self .config [" opt" ] == "Adam" :
110136 self .optimizer = optim .AdamW (
111137 self .parameters (),
112- weight_decay = self .config [' adam_weight_decay' ],
113- lr = self .config ['lr' ]
138+ weight_decay = self .config [" adam_weight_decay" ],
139+ lr = self .config ["lr" ],
114140 )
115- if self .config [' loss' ] == "mse" :
141+ if self .config [" loss" ] == "mse" :
116142 self .loss_fn = nn .MSELoss ()
117143
118144 def save_model (self , filename ):
0 commit comments