File tree Expand file tree Collapse file tree 2 files changed +208
-11
lines changed
Expand file tree Collapse file tree 2 files changed +208
-11
lines changed Original file line number Diff line number Diff line change 11import numpy as np
22import torch as th
33import torch .nn as nn
4- import wandb
54from torch .utils .data import DataLoader
65from torchvision import transforms
76from tqdm import tqdm
87
8+ import wandb
99from CollaborativeCoding import (
1010 MetricWrapper ,
1111 createfolders ,
@@ -132,6 +132,7 @@ def main():
132132 wandb .init (
133133 entity = "ColabCode" ,
134134 project = args .run_name ,
135+ dir = args .resultfolder ,
135136 tags = [args .modelname , args .dataset ],
136137 config = args ,
137138 )
@@ -178,6 +179,9 @@ def main():
178179 train_metrics .resetmetric ()
179180 val_metrics .resetmetric ()
180181
182+ if args .savemodel :
183+ th .save (model , args .modelfolder / f"{ args .modelname } _run:{ args .run_name } .pth" )
184+
181185 testloss = []
182186 model .eval ()
183187 with th .no_grad ():
You can’t perform that action at this time.
0 commit comments