11import numpy as np
22import torch as th
33import torch .nn as nn
4+ import wandb
45from torch .utils .data import DataLoader
56from torchvision import transforms
67from tqdm import tqdm
78
8- import wandb
99from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1010
1111
@@ -22,13 +22,13 @@ def main():
2222 ------
2323
2424 """
25-
25+
2626 args = get_args ()
27-
27+
2828 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
29-
29+
3030 device = args .device
31-
31+
3232 if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9" ]:
3333 augmentations = transforms .Compose (
3434 [
@@ -38,7 +38,7 @@ def main():
3838 )
3939 else :
4040 augmentations = transforms .Compose ([transforms .ToTensor ()])
41-
41+
4242 # Dataset
4343 traindata = load_data (
4444 args .dataset ,
@@ -54,22 +54,22 @@ def main():
5454 download = args .download_data ,
5555 transform = augmentations ,
5656 )
57-
57+
5858 metrics = MetricWrapper (traindata .num_classes , * args .metric )
59-
59+
6060 # Find the shape of the data, if is 2D, add a channel dimension
6161 data_shape = traindata [0 ][0 ].shape
6262 if len (data_shape ) == 2 :
6363 data_shape = (1 , * data_shape )
64-
64+
6565 # load model
6666 model = load_model (
6767 args .modelname ,
6868 image_shape = data_shape ,
6969 num_classes = traindata .num_classes ,
7070 )
7171 model .to (device )
72-
72+
7373 trainloader = DataLoader (
7474 traindata ,
7575 batch_size = args .batchsize ,
@@ -113,11 +113,11 @@ def main():
113113
114114 # wandb.login(key=WANDB_API)
115115 wandb .init (
116- entity = "ColabCode-org" ,
117- # entity="FYS-8805 Exam",
118- project = "Test" ,
119- tags = [args .modelname , args .dataset ]
120- )
116+ entity = "ColabCode-org" ,
117+ # entity="FYS-8805 Exam",
118+ project = "Test" ,
119+ tags = [args .modelname , args .dataset ],
120+ )
121121 wandb .watch (model )
122122 exit ()
123123 for epoch in range (args .epoch ):
0 commit comments