Skip to content

Commit c25a2c8

Browse files
committed
moved argument parsing and handling from main to a separate file to clean up
1 parent ed0eaf2 commit c25a2c8

File tree

3 files changed

+111
-102
lines changed

3 files changed

+111
-102
lines changed

main.py

Lines changed: 11 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchvision import transforms
1010
from tqdm import tqdm
1111

12-
from utils import MetricWrapper, createfolders, load_data, load_model
12+
from utils import MetricWrapper, createfolders, load_data, load_model, get_args
1313

1414

1515
def main():
@@ -25,113 +25,23 @@ def main():
2525
------
2626
2727
"""
28-
parser = argparse.ArgumentParser(
29-
prog="",
30-
description="",
31-
epilog="",
32-
)
33-
# Structuture related values
34-
parser.add_argument(
35-
"--datafolder",
36-
type=Path,
37-
default="Data",
38-
help="Path to where data will be saved during training.",
39-
)
40-
parser.add_argument(
41-
"--resultfolder",
42-
type=Path,
43-
default="Results",
44-
help="Path to where results will be saved during evaluation.",
45-
)
46-
parser.add_argument(
47-
"--modelfolder",
48-
type=Path,
49-
default="Experiments",
50-
help="Path to where model weights will be saved at the end of training.",
51-
)
52-
parser.add_argument(
53-
"--savemodel",
54-
action="store_true",
55-
help="Whether model should be saved or not.",
56-
)
57-
58-
parser.add_argument(
59-
"--download-data",
60-
action="store_true",
61-
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
62-
)
63-
64-
# Data/Model specific values
65-
parser.add_argument(
66-
"--modelname",
67-
type=str,
68-
default="MagnusModel",
69-
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
70-
help="Model which to be trained on",
71-
)
72-
parser.add_argument(
73-
"--dataset",
74-
type=str,
75-
default="svhn",
76-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
77-
help="Which dataset to train the model on.",
78-
)
79-
80-
parser.add_argument(
81-
"--metric",
82-
type=str,
83-
default=["entropy"],
84-
choices=["entropy", "f1", "recall", "precision", "accuracy"],
85-
nargs="+",
86-
help="Which metric to use for evaluation",
87-
)
88-
89-
# Training specific values
90-
parser.add_argument(
91-
"--epoch",
92-
type=int,
93-
default=20,
94-
help="Amount of training epochs the model will do.",
95-
)
96-
parser.add_argument(
97-
"--learning_rate",
98-
type=float,
99-
default=0.001,
100-
help="Learning rate parameter for model training.",
101-
)
102-
parser.add_argument(
103-
"--batchsize",
104-
type=int,
105-
default=64,
106-
help="Amount of training images loaded in one go",
107-
)
108-
parser.add_argument(
109-
"--device",
110-
type=str,
111-
default="cpu",
112-
choices=["cuda", "cpu", "mps"],
113-
help="Which device to run the training on.",
114-
)
115-
parser.add_argument(
116-
"--dry_run",
117-
action="store_true",
118-
help="If true, the code will not run the training loop.",
119-
)
120-
121-
args = parser.parse_args()
28+
args = get_args()
12229

12330
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
12431

12532
device = args.device
12633

12734
metrics = MetricWrapper(*args.metric)
12835

129-
augmentations = transforms.Compose(
130-
[
131-
transforms.Resize((16, 16)), # At least for USPS
132-
transforms.ToTensor(),
133-
]
134-
)
36+
if args.dataset.lower() == "usps_0-6" or args.dataset.lower() == "uspsh5_7_9":
37+
augmentations = transforms.Compose(
38+
[
39+
transforms.Resize((16, 16)),
40+
transforms.ToTensor(),
41+
]
42+
)
43+
else:
44+
augmentations = transforms.Compose([transforms.ToTensor()])
13545

13646
# Dataset
13747
traindata = load_data(

utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"]
1+
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper", "get_args"]
22

3+
from .arg_parser import get_args
34
from .createfolders import createfolders
45
from .load_data import load_data
56
from .load_metric import MetricWrapper

utils/arg_parser.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
5+
def get_args():
6+
parser = argparse.ArgumentParser(
7+
prog="",
8+
description="",
9+
epilog="",
10+
)
11+
# Structuture related values
12+
parser.add_argument(
13+
"--datafolder",
14+
type=Path,
15+
default="Data",
16+
help="Path to where data will be saved during training.",
17+
)
18+
parser.add_argument(
19+
"--resultfolder",
20+
type=Path,
21+
default="Results",
22+
help="Path to where results will be saved during evaluation.",
23+
)
24+
parser.add_argument(
25+
"--modelfolder",
26+
type=Path,
27+
default="Experiments",
28+
help="Path to where model weights will be saved at the end of training.",
29+
)
30+
parser.add_argument(
31+
"--savemodel",
32+
action="store_true",
33+
help="Whether model should be saved or not.",
34+
)
35+
36+
parser.add_argument(
37+
"--download-data",
38+
action="store_true",
39+
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
40+
)
41+
42+
# Data/Model specific values
43+
parser.add_argument(
44+
"--modelname",
45+
type=str,
46+
default="MagnusModel",
47+
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
48+
help="Model which to be trained on",
49+
)
50+
parser.add_argument(
51+
"--dataset",
52+
type=str,
53+
default="svhn",
54+
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
55+
help="Which dataset to train the model on.",
56+
)
57+
58+
parser.add_argument(
59+
"--metric",
60+
type=str,
61+
default=["entropy"],
62+
choices=["entropy", "f1", "recall", "precision", "accuracy"],
63+
nargs="+",
64+
help="Which metric to use for evaluation",
65+
)
66+
67+
# Training specific values
68+
parser.add_argument(
69+
"--epoch",
70+
type=int,
71+
default=20,
72+
help="Amount of training epochs the model will do.",
73+
)
74+
parser.add_argument(
75+
"--learning_rate",
76+
type=float,
77+
default=0.001,
78+
help="Learning rate parameter for model training.",
79+
)
80+
parser.add_argument(
81+
"--batchsize",
82+
type=int,
83+
default=64,
84+
help="Amount of training images loaded in one go",
85+
)
86+
parser.add_argument(
87+
"--device",
88+
type=str,
89+
default="cpu",
90+
choices=["cuda", "cpu", "mps"],
91+
help="Which device to run the training on.",
92+
)
93+
parser.add_argument(
94+
"--dry_run",
95+
action="store_true",
96+
help="If true, the code will not run the training loop.",
97+
)
98+
return parser.parse_args()

0 commit comments

Comments
 (0)