Skip to content

Commit 891f09b

Browse files
authored
Merge pull request #40 from SFI-Visual-Intelligence/Jan-dev
Main.py updates
2 parents 2ff9ad3 + 2341c69 commit 891f09b

File tree

8 files changed

+159
-128
lines changed

8 files changed

+159
-128
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Results/
55
Experiments/
66
_build/
77
bin/
8+
wandb/
9+
wandb_api.py
810

911
# Byte-compiled / optimized / DLL files
1012
__pycache__/

main.py

Lines changed: 38 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import argparse
21
from pathlib import Path
32

43
import numpy as np
@@ -9,7 +8,7 @@
98
from torchvision import transforms
109
from tqdm import tqdm
1110

12-
from utils import MetricWrapper, createfolders, load_data, load_model
11+
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1312

1413

1514
def main():
@@ -25,113 +24,21 @@ def main():
2524
------
2625
2726
"""
28-
parser = argparse.ArgumentParser(
29-
prog="",
30-
description="",
31-
epilog="",
32-
)
33-
# Structure-related values
34-
parser.add_argument(
35-
"--datafolder",
36-
type=Path,
37-
default="Data",
38-
help="Path to where data will be saved during training.",
39-
)
40-
parser.add_argument(
41-
"--resultfolder",
42-
type=Path,
43-
default="Results",
44-
help="Path to where results will be saved during evaluation.",
45-
)
46-
parser.add_argument(
47-
"--modelfolder",
48-
type=Path,
49-
default="Experiments",
50-
help="Path to where model weights will be saved at the end of training.",
51-
)
52-
parser.add_argument(
53-
"--savemodel",
54-
action="store_true",
55-
help="Whether model should be saved or not.",
56-
)
57-
58-
parser.add_argument(
59-
"--download-data",
60-
action="store_true",
61-
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
62-
)
63-
64-
# Data/Model specific values
65-
parser.add_argument(
66-
"--modelname",
67-
type=str,
68-
default="MagnusModel",
69-
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
70-
help="Model which to be trained on",
71-
)
72-
parser.add_argument(
73-
"--dataset",
74-
type=str,
75-
default="svhn",
76-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
77-
help="Which dataset to train the model on.",
78-
)
79-
80-
parser.add_argument(
81-
"--metric",
82-
type=str,
83-
default=["entropy"],
84-
choices=["entropy", "f1", "recall", "precision", "accuracy"],
85-
nargs="+",
86-
help="Which metric to use for evaluation",
87-
)
88-
89-
# Training specific values
90-
parser.add_argument(
91-
"--epoch",
92-
type=int,
93-
default=20,
94-
help="Amount of training epochs the model will do.",
95-
)
96-
parser.add_argument(
97-
"--learning_rate",
98-
type=float,
99-
default=0.001,
100-
help="Learning rate parameter for model training.",
101-
)
102-
parser.add_argument(
103-
"--batchsize",
104-
type=int,
105-
default=64,
106-
help="Amount of training images loaded in one go",
107-
)
108-
parser.add_argument(
109-
"--device",
110-
type=str,
111-
default="cpu",
112-
choices=["cuda", "cpu", "mps"],
113-
help="Which device to run the training on.",
114-
)
115-
parser.add_argument(
116-
"--dry_run",
117-
action="store_true",
118-
help="If true, the code will not run the training loop.",
119-
)
120-
121-
args = parser.parse_args()
27+
args = get_args()
12228

12329
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
12430

12531
device = args.device
12632

127-
metrics = MetricWrapper(*args.metric)
128-
129-
augmentations = transforms.Compose(
130-
[
131-
transforms.Resize((16, 16)), # At least for USPS
132-
transforms.ToTensor(),
133-
]
134-
)
33+
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
34+
augmentations = transforms.Compose(
35+
[
36+
transforms.Resize((16, 16)),
37+
transforms.ToTensor(),
38+
]
39+
)
40+
else:
41+
augmentations = transforms.Compose([transforms.ToTensor()])
13542

13643
# Dataset
13744
traindata = load_data(
@@ -149,6 +56,8 @@ def main():
14956
transform=augmentations,
15057
)
15158

59+
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
60+
15261
# Find the shape of the data; if it is 2D, add a channel dimension
15362
data_shape = traindata[0][0].shape
15463
if len(data_shape) == 2:
@@ -180,28 +89,32 @@ def main():
18089
if args.dry_run:
18190
dry_run_loader = DataLoader(
18291
traindata,
183-
batch_size=1,
92+
batch_size=20,
18493
shuffle=True,
18594
pin_memory=True,
18695
drop_last=True,
18796
)
18897

18998
for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
19099
x, y = x.to(device), y.to(device)
191-
pred = model.forward(x)
100+
logits = model.forward(x)
192101

193-
loss = criterion(y, pred)
102+
loss = criterion(logits, y)
194103
loss.backward()
195104

196105
optimizer.step()
197106
optimizer.zero_grad(set_to_none=True)
198107

199-
break
108+
preds = th.argmax(logits, dim=1)
109+
metrics(y, preds)
200110

111+
break
112+
print(metrics.__getmetrics__())
201113
print("Dry run completed successfully.")
202114
exit(0)
203115

204-
wandb.init(project="", tags=[])
116+
wandb.login(key=WANDB_API)
117+
wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
205118
wandb.watch(model)
206119

207120
for epoch in range(args.epoch):
@@ -210,25 +123,37 @@ def main():
210123
model.train()
211124
for x, y in tqdm(trainloader, desc="Training"):
212125
x, y = x.to(device), y.to(device)
213-
pred = model.forward(x)
126+
logits = model.forward(x)
214127

215-
loss = criterion(y, pred)
128+
loss = criterion(logits, y)
216129
loss.backward()
217130

218131
optimizer.step()
219132
optimizer.zero_grad(set_to_none=True)
220133
trainingloss.append(loss.item())
221134

135+
preds = th.argmax(logits, dim=1)
136+
metrics(y, preds)
137+
138+
wandb.log(metrics.__getmetrics__(str_prefix="Train "))
139+
metrics.__resetvalues__()
140+
222141
evalloss = []
223142
# Eval loop start
224143
model.eval()
225144
with th.no_grad():
226145
for x, y in tqdm(valiloader, desc="Validation"):
227146
x, y = x.to(device), y.to(device)
228-
pred = model.forward(x)
229-
loss = criterion(y, pred)
147+
logits = model.forward(x)
148+
loss = criterion(logits, y)
230149
evalloss.append(loss.item())
231150

151+
preds = th.argmax(logits, dim=1)
152+
metrics(y, preds)
153+
154+
wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
155+
metrics.__resetvalues__()
156+
232157
wandb.log(
233158
{
234159
"Epoch": epoch,

tests/test_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ def test_for_zero_denominator():
8787
def test_accuracy():
8888
import torch
8989

90-
accuracy = Accuracy()
90+
accuracy = Accuracy(num_classes=5)
9191

9292
y_true = torch.tensor([0, 3, 2, 3, 4])
9393
y_pred = torch.tensor([0, 1, 2, 3, 4])
9494

9595
accuracy_score = accuracy(y_true, y_pred)
9696

97-
assert (torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5), (
97+
assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
9898
f"Accuracy Score: {accuracy_score.item()}"
9999
)

utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"]
1+
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper", "get_args"]
22

3+
from .arg_parser import get_args
34
from .createfolders import createfolders
45
from .load_data import load_data
56
from .load_metric import MetricWrapper

utils/arg_parser.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
5+
def get_args():
6+
parser = argparse.ArgumentParser(
7+
prog="",
8+
description="",
9+
epilog="",
10+
)
11+
# Structure-related values
12+
parser.add_argument(
13+
"--datafolder",
14+
type=Path,
15+
default="Data",
16+
help="Path to where data will be saved during training.",
17+
)
18+
parser.add_argument(
19+
"--resultfolder",
20+
type=Path,
21+
default="Results",
22+
help="Path to where results will be saved during evaluation.",
23+
)
24+
parser.add_argument(
25+
"--modelfolder",
26+
type=Path,
27+
default="Experiments",
28+
help="Path to where model weights will be saved at the end of training.",
29+
)
30+
parser.add_argument(
31+
"--savemodel",
32+
action="store_true",
33+
help="Whether model should be saved or not.",
34+
)
35+
36+
parser.add_argument(
37+
"--download-data",
38+
action="store_true",
39+
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
40+
)
41+
42+
# Data/Model specific values
43+
parser.add_argument(
44+
"--modelname",
45+
type=str,
46+
default="MagnusModel",
47+
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
48+
help="Model which to be trained on",
49+
)
50+
parser.add_argument(
51+
"--dataset",
52+
type=str,
53+
default="svhn",
54+
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
55+
help="Which dataset to train the model on.",
56+
)
57+
58+
parser.add_argument(
59+
"--metric",
60+
type=str,
61+
default=["entropy"],
62+
choices=["entropy", "f1", "recall", "precision", "accuracy"],
63+
nargs="+",
64+
help="Which metric to use for evaluation",
65+
)
66+
67+
# Training specific values
68+
parser.add_argument(
69+
"--epoch",
70+
type=int,
71+
default=20,
72+
help="Amount of training epochs the model will do.",
73+
)
74+
parser.add_argument(
75+
"--learning_rate",
76+
type=float,
77+
default=0.001,
78+
help="Learning rate parameter for model training.",
79+
)
80+
parser.add_argument(
81+
"--batchsize",
82+
type=int,
83+
default=64,
84+
help="Amount of training images loaded in one go",
85+
)
86+
parser.add_argument(
87+
"--device",
88+
type=str,
89+
default="cpu",
90+
choices=["cuda", "cpu", "mps"],
91+
help="Which device to run the training on.",
92+
)
93+
parser.add_argument(
94+
"--dry_run",
95+
action="store_true",
96+
help="If true, the code will not run the training loop.",
97+
)
98+
return parser.parse_args()

utils/dataloaders/mnist_0_3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ def __len__(self):
134134

135135
def __getitem__(self, index):
136136
with open(self.labels_path, "rb") as f:
137-
f.seek(8 + index) # Jump to the label position
137+
f.seek(8 + self.idx[index]) # Jump to the label position
138138
label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label
139139

140140
with open(self.images_path, "rb") as f:
141-
f.seek(16 + index * 28 * 28) # Jump to image position
141+
f.seek(16 + self.idx[index] * 28 * 28) # Jump to image position
142142
image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
143143
28, 28
144144
) # Read image data

0 commit comments

Comments
 (0)