
Commit 1947e82

Format using ruff
1 parent 6ad365c commit 1947e82
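
The commit itself records no invocation details, but a pass like this is normally produced by running the ruff formatter over the file, typically "ruff format main.py" from the shell. As a minimal sketch only, the same pass driven from Python (assumes ruff is installed and on PATH; the exact command used for this commit is an assumption):

import subprocess

# Sketch only: reproduce a "Format using ruff" pass on main.py.
# "ruff format" rewrites the file in place; "ruff format --check" or
# "ruff format --diff" would only report what would change.
subprocess.run(["ruff", "format", "main.py"], check=True)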

File tree

1 file changed: +102, -46 lines changed

main.py

Lines changed: 102 additions & 46 deletions
@@ -1,10 +1,12 @@
-import torch as th
-import torch.nn as nn
-from torch.utils.data import DataLoader
 import argparse
-import wandb
+
 import numpy as np
-from utils import MetricWrapper, load_model, load_data, createfolders
+import torch as th
+import torch.nn as nn
+import wandb
+from torch.utils.data import DataLoader
+
+from utils import MetricWrapper, createfolders, load_data, load_model


 def main():
@@ -25,44 +27,100 @@ def main():
         description='',
         epilog='',
     )
-    #Structuture related values
-    parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
-    parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
-    parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
-    parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
-
-    parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
-
-    #Data/Model specific values
-    parser.add_argument('--modelname', type=str, default='MagnusModel',
-                        choices = ['MagnusModel'], help="Model which to be trained on")
-    parser.add_argument('--dataset', type=str, default='svhn',
-                        choices=['svhn'], help='Which dataset to train the model on.')
-
-    parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which metric to use for evaluation')
+    # Structuture related values
+    parser.add_argument(
+        "--datafolder",
+        type=str,
+        default="Data/",
+        help="Path to where data will be saved during training.",
+    )
+    parser.add_argument(
+        "--resultfolder",
+        type=str,
+        default="Results/",
+        help="Path to where results will be saved during evaluation.",
+    )
+    parser.add_argument(
+        "--modelfolder",
+        type=str,
+        default="Experiments/",
+        help="Path to where model weights will be saved at the end of training.",
+    )
+    parser.add_argument(
+        "--savemodel",
+        type=bool,
+        default=False,
+        help="Whether model should be saved or not.",
+    )
+
+    parser.add_argument(
+        "--download-data",
+        type=bool,
+        default=False,
+        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
+    )
+
+    # Data/Model specific values
+    parser.add_argument(
+        "--modelname",
+        type=str,
+        default="MagnusModel",
+        choices=["MagnusModel"],
+        help="Model which to be trained on",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="svhn",
+        choices=["svhn"],
+        help="Which dataset to train the model on.",
+    )
+
+    parser.add_argument(
+        "--metric",
+        type=str,
+        default="entropy",
+        choices=["entropy", "f1", "recall", "precision", "accuracy"],
+        nargs="+",
+        help="Which metric to use for evaluation",
+    )
+
+    # Training specific values
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=20,
+        help="Amount of training epochs the model will do.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=0.001,
+        help="Learning rate parameter for model training.",
+    )
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=64,
+        help="Amount of training images loaded in one go",
+    )

-    #Training specific values
-    parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
-    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
-    parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go')
-
     args = parser.parse_args()
-

     createfolders(args)
-
+
     device = 'cuda' if th.cuda.is_available() else 'cpu'
-
-    #load model
+
+    # load model
     model = load_model()
     model.to(device)
-
+
     metrics = MetricWrapper(*args.metric)
-
-    #Dataset
+
+    # Dataset
     traindata = load_data(args.dataset)
     validata = load_data(args.dataset)
-
+
     trainloader = DataLoader(traindata,
                              batch_size=args.batchsize,
                              shuffle=True,
@@ -72,47 +130,45 @@ def main():
                              batch_size=args.batchsize,
                              shuffle=False,
                              pin_memory=True)
-
+
     criterion = nn.CrossEntropyLoss()
-    optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
-
-
+    optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
+
     wandb.init(project='',
                tags=[])
     wandb.watch(model)
-
+
     for epoch in range(args.epoch):
-
-        #Training loop start
+
+        # Training loop start
         trainingloss = []
         model.train()
         for x, y in traindata:
             x, y = x.to(device), y.to(device)
             pred = model.forward(x)
-
+
             loss = criterion(y, pred)
             loss.backward()
-
+
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
-
+
         evalloss = []
-        #Eval loop start
+        # Eval loop start
         model.eval()
         with th.no_grad():
             for x, y in valiloader:
                 x = x.to(device)
                 pred = model.forward(x)
                 loss = criterion(y, pred)
                 evalloss.append(loss.item())
-
+
         wandb.log({
            'Epoch': epoch,
            'Train loss': np.mean(trainingloss),
            'Evaluation Loss': np.mean(evalloss)
        })
-


 if __name__ == '__main__':
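
The commit only reformats the file, so the loop keeps the repository's own calling conventions: criterion(y, pred), iterating traindata directly, and evaluation targets left on the CPU. For comparison only, a minimal sketch of the same epoch loop written with the usual PyTorch conventions, reusing the imports and the model, trainloader, valiloader, criterion, optimizer, device and args objects defined in main.py above; this is not part of the commit:

for epoch in range(args.epoch):
    # Training: iterate the DataLoader so batching and shuffling apply.
    trainingloss = []
    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        # nn.CrossEntropyLoss expects (input, target), in that order.
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        trainingloss.append(loss.item())

    # Evaluation: no gradients, and the targets also need to be on the device.
    evalloss = []
    model.eval()
    with th.no_grad():
        for x, y in valiloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            evalloss.append(criterion(pred, y).item())

    wandb.log({
        'Epoch': epoch,
        'Train loss': np.mean(trainingloss),
        'Evaluation Loss': np.mean(evalloss),
    })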
