99from torchvision import transforms
1010from tqdm import tqdm
1111
12- from utils import MetricWrapper , createfolders , load_data , load_model
12+ from utils import MetricWrapper , createfolders , load_data , load_model , get_args
1313
1414
1515def main ():
@@ -25,113 +25,23 @@ def main():
2525 ------
2626
2727 """
28- parser = argparse .ArgumentParser (
29- prog = "" ,
30- description = "" ,
31- epilog = "" ,
32- )
33- # Structuture related values
34- parser .add_argument (
35- "--datafolder" ,
36- type = Path ,
37- default = "Data" ,
38- help = "Path to where data will be saved during training." ,
39- )
40- parser .add_argument (
41- "--resultfolder" ,
42- type = Path ,
43- default = "Results" ,
44- help = "Path to where results will be saved during evaluation." ,
45- )
46- parser .add_argument (
47- "--modelfolder" ,
48- type = Path ,
49- default = "Experiments" ,
50- help = "Path to where model weights will be saved at the end of training." ,
51- )
52- parser .add_argument (
53- "--savemodel" ,
54- action = "store_true" ,
55- help = "Whether model should be saved or not." ,
56- )
57-
58- parser .add_argument (
59- "--download-data" ,
60- action = "store_true" ,
61- help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
62- )
63-
64- # Data/Model specific values
65- parser .add_argument (
66- "--modelname" ,
67- type = str ,
68- default = "MagnusModel" ,
69- choices = ["MagnusModel" , "ChristianModel" , "SolveigModel" ],
70- help = "Model which to be trained on" ,
71- )
72- parser .add_argument (
73- "--dataset" ,
74- type = str ,
75- default = "svhn" ,
76- choices = ["svhn" , "usps_0-6" , "uspsh5_7_9" , "mnist_0-3" ],
77- help = "Which dataset to train the model on." ,
78- )
79-
80- parser .add_argument (
81- "--metric" ,
82- type = str ,
83- default = ["entropy" ],
84- choices = ["entropy" , "f1" , "recall" , "precision" , "accuracy" ],
85- nargs = "+" ,
86- help = "Which metric to use for evaluation" ,
87- )
88-
89- # Training specific values
90- parser .add_argument (
91- "--epoch" ,
92- type = int ,
93- default = 20 ,
94- help = "Amount of training epochs the model will do." ,
95- )
96- parser .add_argument (
97- "--learning_rate" ,
98- type = float ,
99- default = 0.001 ,
100- help = "Learning rate parameter for model training." ,
101- )
102- parser .add_argument (
103- "--batchsize" ,
104- type = int ,
105- default = 64 ,
106- help = "Amount of training images loaded in one go" ,
107- )
108- parser .add_argument (
109- "--device" ,
110- type = str ,
111- default = "cpu" ,
112- choices = ["cuda" , "cpu" , "mps" ],
113- help = "Which device to run the training on." ,
114- )
115- parser .add_argument (
116- "--dry_run" ,
117- action = "store_true" ,
118- help = "If true, the code will not run the training loop." ,
119- )
120-
121- args = parser .parse_args ()
28+ args = get_args ()
12229
12330 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
12431
12532 device = args .device
12633
12734 metrics = MetricWrapper (* args .metric )
12835
129- augmentations = transforms .Compose (
130- [
131- transforms .Resize ((16 , 16 )), # At least for USPS
132- transforms .ToTensor (),
133- ]
134- )
36+ if args .dataset .lower () == "usps_0-6" or args .dataset .lower () == "uspsh5_7_9" :
37+ augmentations = transforms .Compose (
38+ [
39+ transforms .Resize ((16 , 16 )),
40+ transforms .ToTensor (),
41+ ]
42+ )
43+ else :
44+ augmentations = transforms .Compose ([transforms .ToTensor ()])
13545
13646 # Dataset
13747 traindata = load_data (
0 commit comments