@@ -11,7 +11,7 @@
 
 
 def build_dataloaders(
-    device, train_seq, valid_seq, train_data, val_data, batch_size, ct_profiles
+    device, train_data, valid_data, batch_size,
 ):
     """
     Create training and validation data loaders.
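Note: this refactor folds the sequences and cell-type profiles into the data bundles themselves, so `build_dataloaders` now takes four arguments. A minimal usage sketch under that assumption (file names are hypothetical, and the bundles are presumed to already carry sequences, targets, and cell-type information in the layout `APAData` expects):

    # Sketch only: assumes each .npy bundle already packs sequences,
    # targets, and cell-type information for APAData.
    import numpy as np

    train_data = np.load("train_bundle.npy", allow_pickle=True)
    valid_data = np.load("valid_bundle.npy", allow_pickle=True)
    train_loader, valid_loader = build_dataloaders(
        "cuda:0", train_data, valid_data, batch_size=64
    )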
@@ -24,21 +24,27 @@ def build_dataloaders(
         Tuple of DataLoader for training and validation datasets.
     """
     train_loader = DataLoader(
-        APAData(train_seq, train_data, ct_profiles, device),
+        APAData(train_data, device),
         batch_size=batch_size,
         shuffle=True,
         drop_last=True,
     )
     valid_loader = DataLoader(
-        APAData(valid_seq, val_data, ct_profiles, device),
+        APAData(valid_data, device),
         batch_size=batch_size,
         shuffle=False,
         drop_last=False,
     )
     return train_loader, valid_loader
 
+def l1_penalty(model, l1_factor):
+    l1_reg = torch.tensor(0.).to(model.device)
+    for param in model.parameters():
+        l1_reg += torch.norm(param, 1)
+    return l1_factor * l1_reg
 
-def train_one_epoch(model, train_loader):
+
+def train_one_epoch(model, train_loader, l1_factor=0.00005):
     """
     Train the model for one epoch.
     Args:
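Note: the new `l1_penalty` helper accumulates the L1 norm of every parameter and scales the sum by `l1_factor`. It reads `model.device`, which assumes the model object exposes a `device` attribute (a plain `nn.Module` does not define one; `APANET` presumably sets it). A runnable sketch on a toy model of what the commented-out `loss = mse_loss + l1_loss` line in the next hunk would compute:

    # Sketch only: exercising l1_penalty on a toy model; APANET itself
    # is assumed to set self.device, which l1_penalty relies on.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    model.device = torch.device("cpu")  # attribute l1_penalty expects
    rmse = torch.sqrt(nn.functional.mse_loss(
        model(torch.randn(8, 4)).squeeze(), torch.zeros(8)
    ))
    loss = rmse + l1_penalty(model, l1_factor=5e-5)
    loss.backward()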
@@ -49,10 +55,13 @@ def train_one_epoch(model, train_loader):
     """
     model.train()
     total_loss, predictions, targets = 0.0, [], []
-    for seq_X, celltype, _, Y in train_loader:
+    for seq_X, Y, celltype, _, _ in train_loader:
         model.optimizer.zero_grad()
         outputs = torch.squeeze(model(seq_X, celltype))
-        loss = torch.sqrt(model.loss_fn(outputs, Y))
+        mse_loss = torch.sqrt(model.loss_fn(outputs, Y))
+        # l1_loss = l1_penalty(model, l1_factor)
+        # loss = mse_loss + l1_loss
+        loss = mse_loss
         loss.backward()
         model.optimizer.step()
         total_loss += loss.item() * seq_X.size(0)
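Note: both epoch loops now unpack five fields per batch, with the target moved to the second slot. The dataset contract they rely on looks like the following self-contained illustration (shapes and field meanings are assumptions, not the real `APAData`):

    # Sketch only: a dummy dataset yielding the 5-tuple order that
    # train_one_epoch and validate_one_epoch unpack as (seq_X, Y, celltype, _, _).
    import torch
    from torch.utils.data import Dataset, DataLoader

    class DummyAPAData(Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, i):
            seq_X = torch.zeros(4, 400)   # one-hot sequence (assumed shape)
            Y = torch.tensor(0.0)         # regression target
            celltype = torch.zeros(10)    # cell-type features (assumed shape)
            return seq_X, Y, celltype, 0, 0  # trailing fields are ignored

    for seq_X, Y, celltype, _, _ in DataLoader(DummyAPAData(), batch_size=4):
        pass  # same unpacking order as the training loop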
@@ -77,7 +86,7 @@ def validate_one_epoch(model, valid_loader):
     model.eval()
     total_loss, predictions, targets = 0.0, [], []
     with torch.no_grad():
-        for seq_X, celltype, _, Y in valid_loader:
+        for seq_X, Y, celltype, _, _ in valid_loader:
             outputs = torch.squeeze(model(seq_X, celltype))
             loss = torch.sqrt(model.loss_fn(outputs, Y))
             total_loss += loss.item() * seq_X.size(0)
@@ -92,13 +101,11 @@ def validate_one_epoch(model, valid_loader):
 
 
 def main_train(
-    train_seq,
-    valid_seq,
     train_data,
     val_data,
-    profiles,
     modelfile,
     device,
+    project_name,
     config,
     use_wandb,
 ):
@@ -115,18 +122,15 @@ def main_train(
     use_wandb = args.use_wandb.lower() == "true"
     train_loader, valid_loader = build_dataloaders(
         device,
-        train_seq,
-        valid_seq,
         train_data,
         val_data,
         config["batch_size"],
-        profiles,
     )
     with tqdm(range(config["epochs"]), unit="epoch") as tepochs:
         if use_wandb:
             wandb.login()
             with wandb.init(
-                project=config["project_name"],
+                project=project_name,
                 settings=wandb.Settings(start_method="thread"),
             ):
                 model = APANET(config)
@@ -167,18 +171,9 @@ def main_train(
     parser.add_argument(
         "--train_data", type=str, required=True, help="Path to training data file"
     )
-    parser.add_argument(
-        "--train_seq", type=str, required=True, help="Path to training sequences file"
-    )
     parser.add_argument(
         "--valid_data", type=str, required=True, help="Path to validation data file"
     )
-    parser.add_argument(
-        "--valid_seq", type=str, required=True, help="Path to validation sequences file"
-    )
-    parser.add_argument(
-        "--profiles", type=str, required=True, help="Path to cell type profiles file"
-    )
     parser.add_argument(
         "--modelfile", type=str, required=True, help="Path to save the trained model"
     )
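Note: with the `--train_seq`, `--valid_seq`, and `--profiles` flags removed, an invocation reduces to the remaining arguments. A hypothetical example (script name and values are invented; `--device`, `--project_name`, `--batch_size`, and `--use_wandb` are inferred from the `args.*` references elsewhere in this file):

    python train_apanet.py \
        --train_data train_bundle.npy \
        --valid_data valid_bundle.npy \
        --modelfile apanet.pt \
        --device cuda:0 \
        --project_name apa-model \
        --batch_size 64 \
        --use_wandb false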
@@ -214,10 +209,7 @@ def main_train(
     np.random.seed(7)
 
     train_data = np.load(args.train_data, allow_pickle=True)
-    train_seq = np.load(args.train_seq, allow_pickle=True)
     valid_data = np.load(args.valid_data, allow_pickle=True)
-    valid_seq = np.load(args.valid_seq, allow_pickle=True)
-    profiles = pd.read_csv(args.profiles, index_col=0, sep="\t")
 
     config = {
         "batch_size": args.batch_size,
@@ -227,35 +219,38 @@
         "opt": "Adam",
         "loss": "mse",
         "lr": 2.5e-05,
-        "adam_weight_decay": 0.06,
-        "conv1kc": 128,
+        "adam_weight_decay": 0.09,  # 0.06 before
+        "conv1kc": 128,  # 128, 64
         "conv1ks": 12,
         "conv1st": 1,
-        "pool1ks": 25,
-        "pool1st": 25,
-        "cnvpdrop1": 0.2,
+        "pool1ks": 16,
+        "pool1st": 16,
+        "cnvpdrop1": 0,
         "Matt_heads": 8,
         "Matt_drop": 0.2,
         "fc1_dims": [
-            8192,
+            8192,  # 8192, 5120
             4048,
             1024,
             512,
             256,
         ],  # first dimension will be calculated dynamically
-        "fc1_dropouts": [0.3, 0.25, 0.25, 0.2, 0.1],
+        "fc1_dropouts": [0.25, 0.25, 0.25, 0, 0],
         "fc2_dims": [128, 32, 16, 1],  # first dimension will be calculated dynamically
         "fc2_dropouts": [0.2, 0.2, 0, 0],
+        "psa_query_dim": 128,  # make sure this is correct
+        "psa_num_layers": 1,
+        "psa_nhead": 1,
+        "psa_dim_feedforward": 1024,
+        "psa_dropout": 0,
     }
 
     main_train(
-        train_seq,
-        valid_seq,
         train_data,
         valid_data,
-        profiles,
         args.modelfile,
         args.device,
+        args.project_name,
         config,
         args.use_wandb,
     )
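Note: the five new `psa_*` keys imply a self-attention block inside `APANET` that this diff does not show. Purely as an illustration of how such hyperparameters conventionally map onto a stock PyTorch encoder (not APANET's actual module):

    # Sketch only: one plausible consumer of the psa_* keys; the real
    # PSA module inside APANET is not part of this diff.
    import torch.nn as nn

    def build_psa(config):
        layer = nn.TransformerEncoderLayer(
            d_model=config["psa_query_dim"],                # 128
            nhead=config["psa_nhead"],                      # 1
            dim_feedforward=config["psa_dim_feedforward"],  # 1024
            dropout=config["psa_dropout"],                  # 0
            batch_first=True,
        )
        return nn.TransformerEncoder(layer, num_layers=config["psa_num_layers"])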