Commit ed1f1ee

Merge pull request #21 from SFI-Visual-Intelligence/christian/dataloader
- Implemented USPS dataloader for the digits 0-6.
- Fixed a few bugs related to relative imports (changing `from metrics import ...` to `from .metrics import ...`).
- Changed from using `os.path` to the superior `pathlib`.
- Added `--dry_run` option to the CLI interface.
- Added `mps` backend to `--device`.

A combined usage sketch of these changes is shown below.
2 parents 53d23e3 + d045a2a commit ed1f1ee
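Taken together, the changes are meant to be exercised roughly like this. This is an illustrative sketch based on the calls in the new `main.py`, not code from the commit; a quick smoke test on Apple silicon would be e.g. `python main.py --dataset usps_0-6 --device mps --dry_run`.

```python
# Illustrative only: mirrors how main.py wires the new pieces together.
from pathlib import Path

from torch.utils.data import DataLoader

from utils import createfolders, load_data

datafolder = Path("Data")      # matches the --datafolder default
createfolders(datafolder)      # pathlib-based folder creation from this PR

traindata = load_data(         # new keyword-based load_data call
    "usps_0-6",
    train=True,
    data_path=datafolder,
    download=True,
)
trainloader = DataLoader(traindata, batch_size=64, shuffle=True)
```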

File tree

11 files changed (+325, -93 lines)

environment.yml

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,14 @@ dependencies:
   - sphinx-autobuild
   - sphinx-rtd-theme
   - pip
+  - h5py
+  - black
+  - isort
+  - jupyterlab
+  - numpy
+  - pandas
   - pytest
+  - ruff
+  - scalene
 prefix: /opt/miniconda3/envs/cc-exam
 

main.py

Lines changed: 147 additions & 64 deletions
@@ -1,68 +1,148 @@
-import torch as th
-import torch.nn as nn
-from torch.utils.data import DataLoader
 import argparse
-import wandb
+from pathlib import Path
+
 import numpy as np
-from utils import MetricWrapper, load_model, load_data, createfolders
+import torch as th
+import torch.nn as nn
+import wandb
+from torch.utils.data import DataLoader
+
+from utils import MetricWrapper, createfolders, load_data, load_model
 
 
 def main():
-    '''
-
+    """
+
     Parameters
     ----------
-
+
     Returns
     -------
-
+
     Raises
     ------
-
-    '''
+
+    """
     parser = argparse.ArgumentParser(
-        prog='',
-        description='',
-        epilog='',
-    )
-    #Structuture related values
-    parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
-    parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
-    parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
-    parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
-
-    parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
-
-    #Data/Model specific values
-    parser.add_argument('--modelname', type=str, default='MagnusModel',
-                        choices = ['MagnusModel'], help="Model which to be trained on")
-    parser.add_argument('--dataset', type=str, default='svhn',
-                        choices=['svhn'], help='Which dataset to train the model on.')
-
-    parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which metric to use for evaluation')
-
-    #Training specific values
-    parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
-    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
-    parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go')
-
+        prog="",
+        description="",
+        epilog="",
+    )
+    # Structuture related values
+    parser.add_argument(
+        "--datafolder",
+        type=Path,
+        default="Data",
+        help="Path to where data will be saved during training.",
+    )
+    parser.add_argument(
+        "--resultfolder",
+        type=Path,
+        default="Results",
+        help="Path to where results will be saved during evaluation.",
+    )
+    parser.add_argument(
+        "--modelfolder",
+        type=Path,
+        default="Experiments",
+        help="Path to where model weights will be saved at the end of training.",
+    )
+    parser.add_argument(
+        "--savemodel",
+        type=bool,
+        default=False,
+        help="Whether model should be saved or not.",
+    )
+
+    parser.add_argument(
+        "--download-data",
+        type=bool,
+        default=False,
+        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
+    )
+
+    # Data/Model specific values
+    parser.add_argument(
+        "--modelname",
+        type=str,
+        default="MagnusModel",
+        choices=["MagnusModel"],
+        help="Model which to be trained on",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="svhn",
+        choices=["svhn", "usps_0-6"],
+        help="Which dataset to train the model on.",
+    )
+
+    parser.add_argument(
+        "--metric",
+        type=str,
+        default=["entropy"],
+        choices=["entropy", "f1", "recall", "precision", "accuracy"],
+        nargs="+",
+        help="Which metric to use for evaluation",
+    )
+
+    # Training specific values
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=20,
+        help="Amount of training epochs the model will do.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=0.001,
+        help="Learning rate parameter for model training.",
+    )
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=64,
+        help="Amount of training images loaded in one go",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu", "mps"],
+        help="Which device to run the training on.",
+    )
+    parser.add_argument(
+        "--dry_run",
+        action="store_true",
+        help="If true, the code will not run the training loop.",
+    )
+
     args = parser.parse_args()
-
-
-    createfolders(args)
-
-    device = 'cuda' if th.cuda.is_available() else 'cpu'
-
-    #load model
-    model = load_model()
+
+    createfolders(args.datafolder, args.resultfolder, args.modelfolder)
+
+    device = args.device
+
+    # load model
+    model = load_model(args.modelname)
     model.to(device)
-
+
     metrics = MetricWrapper(*args.metric)
-
-    #Dataset
-    traindata = load_data(args.dataset)
-    validata = load_data(args.dataset)
-
+
+    # Dataset
+    traindata = load_data(
+        args.dataset,
+        train=True,
+        data_path=args.datafolder,
+        download=args.download_data,
+    )
+    validata = load_data(
+        args.dataset,
+        train=False,
+        data_path=args.datafolder,
+    )
+
     trainloader = DataLoader(traindata,
                              batch_size=args.batchsize,
                              shuffle=True,
@@ -72,47 +152,50 @@ def main():
                             batch_size=args.batchsize,
                             shuffle=False,
                             pin_memory=True)
-
+
     criterion = nn.CrossEntropyLoss()
-    optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
-
-
+    optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    # This allows us to load all the components without running the training loop
+    if args.dry_run:
+        print("Dry run completed")
+        exit(0)
+
     wandb.init(project='',
               tags=[])
     wandb.watch(model)
-
+
     for epoch in range(args.epoch):
-
-        #Training loop start
+
+        # Training loop start
         trainingloss = []
         model.train()
-        for x, y in traindata:
+        for x, y in trainloader:
             x, y = x.to(device), y.to(device)
             pred = model.forward(x)
-
+
             loss = criterion(y, pred)
             loss.backward()
-
+
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
-
+
         evalloss = []
-        #Eval loop start
+        # Eval loop start
         model.eval()
         with th.no_grad():
            for x, y in valiloader:
                x = x.to(device)
                pred = model.forward(x)
                loss = criterion(y, pred)
                evalloss.append(loss.item())
-
+
         wandb.log({
             'Epoch': epoch,
             'Train loss': np.mean(trainingloss),
             'Evaluation Loss': np.mean(evalloss)
         })
-
 
 
 if __name__ == '__main__':
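One behavioural change in `main.py` worth noting: the old code picked the device via `th.cuda.is_available()`, while the new code uses `--device` exactly as given. If a fallback were wanted, a small helper along these lines could restore it; this is purely a sketch and not part of the commit:

```python
# Hypothetical helper (not in this commit): fall back to CPU when the
# requested backend is unavailable on the current machine.
import torch as th


def resolve_device(requested: str) -> str:
    if requested == "cuda" and not th.cuda.is_available():
        return "cpu"
    if requested == "mps" and not th.backends.mps.is_available():
        return "cpu"
    return requested
```

With such a helper, `device = resolve_device(args.device)` would take the place of `device = args.device`.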

utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+__all__ = ['createfolders', 'load_data', 'load_model', 'MetricWrapper']
+
 from .createfolders import createfolders
 from .load_data import load_data
 from .load_metric import MetricWrapper

utils/createfolders.py

Lines changed: 24 additions & 23 deletions
@@ -1,9 +1,9 @@
 import argparse
-import os
+from pathlib import Path
 from tempfile import TemporaryDirectory
 
 
-def createfolders(args) -> None:
+def createfolders(*dirs: Path) -> None:
     """
     Creates folders for storing data, results, model weights.
 
@@ -14,43 +14,44 @@ def createfolders(args) -> None:
 
     """
 
-    if not os.path.exists(args.datafolder):
-        os.makedirs(args.datafolder)
-        print(f"Created a folder at {args.datafolder}")
-
-    if not os.path.exists(args.resultfolder):
-        os.makedirs(args.resultfolder)
-        print(f"Created a folder at {args.resultfolder}")
-
-    if not os.path.exists(args.modelfolder):
-        os.makedirs(args.modelfolder)
-        print(f"Created a folder at {args.modelfolder}")
+    for dir in dirs:
+        dir.mkdir(parents=True, exist_ok=True)
 
 
 def test_createfolders():
-    with TemporaryDirectory(dir="tmp/") as temp_dir:
+    with TemporaryDirectory() as temp_dir:
+        temp_dir = Path(temp_dir)
+
         parser = argparse.ArgumentParser()
+
         # Structuture related values
         parser.add_argument(
             "--datafolder",
-            type=str,
-            default=os.path.join(temp_dir, "Data/"),
+            type=Path,
+            default=temp_dir / "Data",
             help="Path to where data will be saved during training.",
         )
         parser.add_argument(
             "--resultfolder",
-            type=str,
-            default=os.path.join(temp_dir, "Results/"),
+            type=Path,
+            default=temp_dir / "Results",
             help="Path to where results will be saved during evaluation.",
         )
         parser.add_argument(
             "--modelfolder",
-            type=str,
-            default=os.path.join(temp_dir, "Experiments/"),
+            type=Path,
+            default=temp_dir / "Experiments",
            help="Path to where model weights will be saved at the end of training.",
         )
 
-        args = parser.parse_args()
-        createfolders(args)
+        args = parser.parse_args([
+            "--datafolder", temp_dir / "Data",
+            "--resultfolder", temp_dir / "Results",
+            "--modelfolder", temp_dir / "Experiments"
+        ])
+
+        createfolders(args.datafolder, args.resultfolder, args.modelfolder)
 
-        return
+        assert (temp_dir / "Data").exists()
+        assert (temp_dir / "Results").exists()
+        assert (temp_dir / "Experiments").exists()

utils/dataloaders/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+__all__ = ["USPSDataset0_6"]
+
+from .usps_0_6 import USPSDataset0_6
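The `usps_0_6.py` module itself is not shown in this diff. Purely as an illustration, a minimal sketch of such a dataset class could look like the following, assuming the USPS digits live in an HDF5 file (`h5py` was added to `environment.yml`) with `train`/`test` groups containing `data` and `target` arrays; the constructor keywords mirror how `load_data` is called in `main.py`, everything else is an assumption:

```python
# Illustrative sketch only -- the real usps_0_6.py may differ.
from pathlib import Path

import h5py
import torch as th
from torch.utils.data import Dataset


class USPSDataset0_6(Dataset):
    """USPS digits restricted to the classes 0-6 (assumed file layout)."""

    def __init__(self, data_path: Path, train: bool = True, download: bool = False):
        # Download handling is omitted in this sketch.
        split = "train" if train else "test"
        with h5py.File(Path(data_path) / "usps.h5", "r") as f:
            data = f[split]["data"][:]        # assumed dataset keys
            targets = f[split]["target"][:]
        keep = targets <= 6                   # keep only the digits 0-6
        self.data = th.as_tensor(data[keep], dtype=th.float32)
        self.targets = th.as_tensor(targets[keep], dtype=th.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]
```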
