File tree Expand file tree Collapse file tree 1 file changed +28
-11
lines changed
Expand file tree Collapse file tree 1 file changed +28
-11
lines changed Original file line number Diff line number Diff line change 1111
1212
1313def main ():
14- '''
15-
14+ """
15+
1616 Parameters
1717 ----------
18-
18+
1919 Returns
2020 -------
21-
21+
2222 Raises
2323 ------
24-
25- '''
24+
25+ """
2626 parser = argparse .ArgumentParser (
27- prog = '' ,
28- description = '' ,
29- epilog = '' ,
27+ prog = "" ,
28+ description = "" ,
29+ epilog = "" ,
3030 )
3131 # Structuture related values
3232 parser .add_argument (
@@ -105,15 +105,27 @@ def main():
105105 default = 64 ,
106106 help = "Amount of training images loaded in one go" ,
107107 )
108+ parser .add_argument (
109+ "--device" ,
110+ type = str ,
111+ default = "cuda" ,
112+ choices = ["cuda" , "cpu" , "mps" ],
113+ help = "Which device to run the training on." ,
114+ )
115+ parser .add_argument (
116+ "--dry_run" ,
117+ action = "store_true" ,
118+ help = "If true, the code will not run the training loop." ,
119+ )
108120
109121 args = parser .parse_args ()
110122
111123 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
112124
113- device = 'cuda' if th . cuda . is_available () else 'cpu'
125+ device = args . device
114126
115127 # load model
116- model = load_model ()
128+ model = load_model (args . modelname )
117129 model .to (device )
118130
119131 metrics = MetricWrapper (* args .metric )
@@ -144,6 +156,11 @@ def main():
144156 criterion = nn .CrossEntropyLoss ()
145157 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
146158
159+ # This allows us to load all the components without running the training loop
160+ if args .dry_run :
161+ print ("Dry run completed" )
162+ exit (0 )
163+
147164 wandb .init (project = '' ,
148165 tags = [])
149166 wandb .watch (model )
You can’t perform that action at this time.
0 commit comments