Commit e2ec6d8

update
1 parent 8cff8bc commit e2ec6d8

3 files changed: 90 additions & 63 deletions

apamodel/blocks.py

Lines changed: 33 additions & 0 deletions
@@ -64,3 +64,36 @@ def __init__(self, layer_dims, dropouts, dropout=False):
 
     def forward(self, x):
         return self.op(x)
+
+
+class ProcessSelfAttn(nn.Module):
+    """
+    Implements the self-attention mechanism.
+    Attributes
+    ----------
+    nhead : int
+        The number of attention heads.
+        Each head computes a separate attention score for each token.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_layers: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.2,
+    ):
+        super().__init__()
+        self.encoder_layer = nn.TransformerEncoderLayer(
+            embed_dim,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation="gelu",
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers)
+
+    def forward(self, latent):
+        return self.transformer(latent)
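For context on the new block: with batch_first=True, nn.TransformerEncoderLayer expects input shaped (batch, seq_len, embed_dim), and embed_dim must be divisible by nhead. A minimal standalone sketch (the shapes here are illustrative, not from this commit):

    import torch
    from blocks import ProcessSelfAttn

    psa = ProcessSelfAttn(embed_dim=128, num_layers=1, nhead=1)
    latent = torch.randn(8, 80, 128)  # (batch, seq_len, embed_dim), since batch_first=True
    out = psa(latent)
    assert out.shape == latent.shape  # the encoder stack preserves the input shape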

apamodel/model.py

Lines changed: 26 additions & 27 deletions
@@ -2,13 +2,14 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
-from blocks import ConvBlock, FCBlock
+from blocks import ConvBlock, FCBlock, ProcessSelfAttn
 import numpy as np
 
-RBP_COUNT = 279
+RBP_COUNT = 327
 FIX_SEQ_LEN = 4000
 
 
+
 class APAData(Dataset):
     """
     APAData is a dataset class for APA-Net model.
@@ -19,40 +20,30 @@ class APAData(Dataset):
         device (str): Device to use (e.g., 'cuda' or 'cpu').
     """
 
-    def __init__(self, seqs, df, ct, device):
+    def __init__(self, data, device):
         self.device = device
         self.reg_label = torch.from_numpy(
-            np.array(df[:, 3].tolist(), dtype=np.float32)
+            np.array(data[:, 7].tolist(), dtype=np.float32)
         ).to(device)
-        self.seq_idx = torch.from_numpy(np.array(df[:, 1].tolist(), dtype=np.int32)).to(
+        self.oneH_seqs = torch.from_numpy(np.array(data[:, 6].tolist())).to(
             device
         )
-        self.oneH_seqs = torch.from_numpy(np.array(list(seqs[:, 3]), dtype=np.int8)).to(
+        self.ct_profiles = torch.from_numpy(np.array(data[:, 8].tolist())).to(
             device
         )
-        self.oneH_seq_indexes = torch.from_numpy(
-            np.array(seqs[:, 0], dtype=np.int32)
-        ).to(device)
-        self.celltypes = df[:, 2]
-        self.ct_profiles = ct
+        self.celltype_name = data[:, 1].tolist()
+        self.switch_name = data[:, 5].tolist()
 
     def __len__(self):
         return self.reg_label.shape[0]
 
     def __getitem__(self, idx):
-        seq_idx = self.seq_idx[idx]
-        seq = (
-            self.oneH_seqs[torch.where(self.oneH_seq_indexes == seq_idx)]
-            .squeeze()
-            .type(torch.cuda.FloatTensor)
-        )
+        seq = self.oneH_seqs[idx].type(torch.cuda.FloatTensor)
         reg_label = self.reg_label[idx]
-        celltype_name = self.celltypes[idx]
-        celltype = torch.from_numpy(
-            self.ct_profiles[celltype_name].values.astype(np.float32)
-        ).to(self.device)
-        return (seq, celltype, celltype_name, reg_label)
-
+        celltype_profile = self.ct_profiles[idx].type(torch.cuda.FloatTensor)
+        celltype_name = self.celltype_name[idx]
+        switch_name = self.switch_name[idx]
+        return (seq, reg_label, celltype_profile, celltype_name, switch_name)
 
 class APANET(nn.Module):
     """
@@ -108,6 +99,13 @@ def _build_model(self):
             dropouts=self.config["fc2_dropouts"],
             dropout=True,
         )
+        self.process_self_attn = ProcessSelfAttn(
+            self.config["psa_query_dim"],
+            self.config["psa_num_layers"],
+            self.config["psa_nhead"],
+            self.config["psa_dim_feedforward"],
+            self.config["psa_dropout"],
+        )
 
     def _get_conv1d_out_length(self, l_in, kernel, stride, pool_kernel, pool_stride):
         """Utility method to calculate output length of Conv1D layer."""
@@ -118,10 +116,11 @@ def _get_conv1d_out_length(self, l_in, kernel, stride, pool_kernel, pool_stride)
 
     def forward(self, seq, celltype):
         # Convolutional forward
-        x_conv = self.conv_block_1(seq)
-        x = x_conv.permute(2, 0, 1)  # reshape for attention block
-        x, _ = self.attention(x, x, x)
-        x = x.permute(1, 2, 0)  # reshape back
+        x_conv = self.conv_block_1(seq)  # (batch, 64/128 dim, 80 len)
+        x = x_conv.permute(0, 2, 1)  # reshape to (batch, len, dim) for the batch-first attention block
+        # x, _ = self.attention(x, x, x)
+        x = self.process_self_attn(x)
+        x = x.permute(0, 2, 1)  # reshape back to (batch, dim, len)
         x = x + x_conv  # add residual connection
         x = torch.flatten(x, 1)  # flatten for FC layers
         x = self.fc1(x)  # FC block 1
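One constraint worth making explicit in the rewired forward pass: the residual add x + x_conv only works if ProcessSelfAttn preserves the channel dimension of the conv output, i.e. psa_query_dim must equal conv1kc; that is presumably what the "make sure this is correct" comment in the config guards. A shape-flow sketch (sizes illustrative):

    import torch

    batch, channels, length = 8, 128, 80           # channels == conv1kc == psa_query_dim
    x_conv = torch.randn(batch, channels, length)  # stand-in for conv_block_1 output
    x = x_conv.permute(0, 2, 1)                    # (batch, length, channels) for batch-first attention
    # x = self.process_self_attn(x)                # shape-preserving encoder stack
    x = x.permute(0, 2, 1)                         # back to (batch, channels, length)
    assert (x + x_conv).shape == x_conv.shape      # residual add needs matching shapes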

apamodel/train_script.py

Lines changed: 31 additions & 36 deletions
@@ -11,7 +11,7 @@
 
 
 def build_dataloaders(
-    device, train_seq, valid_seq, train_data, val_data, batch_size, ct_profiles
+    device, train_data, valid_data, batch_size,
 ):
     """
     Create training and validation data loaders.
@@ -24,21 +24,27 @@ def build_dataloaders(
         Tuple of DataLoader for training and validation datasets.
     """
     train_loader = DataLoader(
-        APAData(train_seq, train_data, ct_profiles, device),
+        APAData(train_data, device),
         batch_size=batch_size,
         shuffle=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
-        APAData(valid_seq, val_data, ct_profiles, device),
+        APAData(valid_data, device),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    return train_loader, valid_loader
 
+def l1_penalty(model, l1_factor):
+    l1_reg = torch.tensor(0.).to(model.device)
+    for param in model.parameters():
+        l1_reg += torch.norm(param, 1)
+    return l1_factor * l1_reg
 
-def train_one_epoch(model, train_loader):
+
+def train_one_epoch(model, train_loader, l1_factor=0.00005):
     """
     Train the model for one epoch.
     Args:
@@ -49,10 +55,13 @@ def train_one_epoch(model, train_loader):
     """
     model.train()
     total_loss, predictions, targets = 0.0, [], []
-    for seq_X, celltype, _, Y in train_loader:
+    for seq_X, Y, celltype, _, _ in train_loader:
         model.optimizer.zero_grad()
         outputs = torch.squeeze(model(seq_X, celltype))
-        loss = torch.sqrt(model.loss_fn(outputs, Y))
+        mse_loss = torch.sqrt(model.loss_fn(outputs, Y))
+        # l1_loss = l1_penalty(model, l1_factor)
+        # loss = mse_loss + l1_loss
+        loss = mse_loss
         loss.backward()
         model.optimizer.step()
         total_loss += loss.item() * seq_X.size(0)
@@ -77,7 +86,7 @@ def validate_one_epoch(model, valid_loader):
     model.eval()
     total_loss, predictions, targets = 0.0, [], []
     with torch.no_grad():
-        for seq_X, celltype, _, Y in valid_loader:
+        for seq_X, Y, celltype, _, _ in valid_loader:
             outputs = torch.squeeze(model(seq_X, celltype))
             loss = torch.sqrt(model.loss_fn(outputs, Y))
             total_loss += loss.item() * seq_X.size(0)
@@ -92,13 +101,11 @@ def validate_one_epoch(model, valid_loader):
 
 
 def main_train(
-    train_seq,
-    valid_seq,
     train_data,
     val_data,
-    profiles,
     modelfile,
     device,
+    project_name,
     config,
     use_wandb,
 ):
@@ -115,18 +122,15 @@ def main_train(
     use_wandb = args.use_wandb.lower() == "true"
     train_loader, valid_loader = build_dataloaders(
         device,
-        train_seq,
-        valid_seq,
         train_data,
         val_data,
         config["batch_size"],
-        profiles,
     )
     with tqdm(range(config["epochs"]), unit="epoch") as tepochs:
         if use_wandb:
             wandb.login()
             with wandb.init(
-                project=config["project_name"],
+                project=project_name,
                 settings=wandb.Settings(start_method="thread"),
             ):
                 model = APANET(config)
@@ -167,18 +171,9 @@ def main_train(
     parser.add_argument(
         "--train_data", type=str, required=True, help="Path to training data file"
     )
-    parser.add_argument(
-        "--train_seq", type=str, required=True, help="Path to training sequences file"
-    )
     parser.add_argument(
         "--valid_data", type=str, required=True, help="Path to validation data file"
     )
-    parser.add_argument(
-        "--valid_seq", type=str, required=True, help="Path to validation sequences file"
-    )
-    parser.add_argument(
-        "--profiles", type=str, required=True, help="Path to cell type profiles file"
-    )
     parser.add_argument(
         "--modelfile", type=str, required=True, help="Path to save the trained model"
     )
@@ -214,10 +209,7 @@ def main_train(
     np.random.seed(7)
 
     train_data = np.load(args.train_data, allow_pickle=True)
-    train_seq = np.load(args.train_seq, allow_pickle=True)
     valid_data = np.load(args.valid_data, allow_pickle=True)
-    valid_seq = np.load(args.valid_seq, allow_pickle=True)
-    profiles = pd.read_csv(args.profiles, index_col=0, sep="\t")
 
     config = {
         "batch_size": args.batch_size,
@@ -227,35 +219,38 @@
         "opt": "Adam",
         "loss": "mse",
         "lr": 2.5e-05,
-        "adam_weight_decay": 0.06,
-        "conv1kc": 128,
+        "adam_weight_decay": 0.09,  # 0.06 before
+        "conv1kc": 128,  # 128, 64
         "conv1ks": 12,
         "conv1st": 1,
-        "pool1ks": 25,
-        "pool1st": 25,
-        "cnvpdrop1": 0.2,
+        "pool1ks": 16,
+        "pool1st": 16,
+        "cnvpdrop1": 0,
         "Matt_heads": 8,
         "Matt_drop": 0.2,
         "fc1_dims": [
-            8192,
+            8192,  # 8192, 5120
             4048,
             1024,
             512,
             256,
         ],  # first dimension will be calculated dynamically
-        "fc1_dropouts": [0.3, 0.25, 0.25, 0.2, 0.1],
+        "fc1_dropouts": [0.25, 0.25, 0.25, 0, 0],
         "fc2_dims": [128, 32, 16, 1],  # first dimension will be calculated dynamically
         "fc2_dropouts": [0.2, 0.2, 0, 0],
+        "psa_query_dim": 128,  # make sure this is correct
+        "psa_num_layers": 1,
+        "psa_nhead": 1,
+        "psa_dim_feedforward": 1024,
+        "psa_dropout": 0,
     }
 
     main_train(
-        train_seq,
-        valid_seq,
         train_data,
         valid_data,
-        profiles,
         args.modelfile,
         args.device,
+        args.project_name,
         config,
         args.use_wandb,
     )
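The new l1_penalty helper is defined but left commented out in train_one_epoch. A small sketch of what it computes, on a toy module (l1_penalty reads model.device, which APANET presumably sets; a plain nn.Module has no such attribute, so it is set by hand here):

    import torch
    import torch.nn as nn

    toy = nn.Linear(4, 2)
    toy.device = "cpu"  # l1_penalty reads model.device

    penalty = l1_penalty(toy, l1_factor=5e-5)
    # torch.norm(p, 1) with no dim is the sum of absolute values of all elements:
    manual = 5e-5 * sum(p.abs().sum() for p in toy.parameters())
    assert torch.allclose(penalty, manual)

Re-enabling the regularizer then amounts to uncommenting the two lines in train_one_epoch, so that loss = mse_loss + l1_penalty(model, l1_factor).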
