Commit 84e6ca8

Merge pull request #51 from SFI-Visual-Intelligence/main
Sync
2 parents 5ff4aaf + b93ee66 commit 84e6ca8

32 files changed: +1276 additions, −336 deletions

.github/workflows/test.yml

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+name: Test
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - uses: mamba-org/setup-micromamba@v1
+        with:
+          micromamba-version: '2.0.5-0' # any version from https://github.com/mamba-org/micromamba-releases
+          environment-file: environment.yml
+          init-shell: bash
+          cache-environment: true
+          post-cleanup: 'all'
+          generate-run-shell: false
+
+      - name: Run tests
+        run: |
+          PYTHONPATH=. pytest tests
+        shell: bash -el {0}
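
A note on the `PYTHONPATH=. pytest tests` invocation: prepending the repository root to `PYTHONPATH` lets the test modules import top-level packages such as `utils` without installing the project. As a purely hypothetical illustration (the repository's actual test suite is not part of this diff), a test that this command would discover could look like:

# tests/test_load_model.py -- hypothetical example; the real tests are not shown in this diff.
# With PYTHONPATH=. set, top-level modules such as `utils` resolve from the repo root.
from utils import load_model


def test_load_model_builds_a_model():
    # Arguments mirror the load_model call in main.py; the shape and class count are made up.
    model = load_model("MagnusModel", image_shape=(1, 16, 16), num_classes=7)
    assert hasattr(model, "forward")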

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,8 @@ Results/
 Experiments/
 _build/
 bin/
+wandb/
+wandb_api.py

 #Magnus specific
 docker/*

doc/about.md

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 # About this code

-Work in progress ...
+Work is still in progress ...

doc/conf.py

Lines changed: 4 additions & 0 deletions

@@ -7,8 +7,12 @@

 extensions = [
     "myst_parser",  # in order to use markdown
+    "autoapi.extension",  # in order to generate API documentation
 ]

+# search this directory for Python files
+autoapi_dirs = ["../utils"]
+
 myst_enable_extensions = [
     "colon_fence",  # ::: can be used instead of ``` for better rendering
 ]
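
With `autoapi.extension` enabled, Sphinx scans `autoapi_dirs` (here the `utils` package one level up from `doc/`) and generates API pages from whatever docstrings it finds there. As a sketch of what that means in practice, a helper like `createfolders` (called from main.py; its real body is not in this diff) only needs a docstring to show up in the generated docs. Hypothetically:

# Hypothetical body and docstring for a utils helper -- shown only to
# illustrate what autoapi renders; the actual implementation is not in this diff.
from pathlib import Path


def createfolders(*folders: Path) -> None:
    """Create each given directory if it does not already exist.

    autoapi extracts this docstring into the generated API documentation.
    """
    for folder in folders:
        Path(folder).mkdir(parents=True, exist_ok=True)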

environment.yml

Lines changed: 4 additions & 0 deletions

@@ -18,5 +18,9 @@ dependencies:
   - pytest
   - ruff
   - scalene
+  - tqdm
+  - pip:
+      - torch
+      - torchvision
 prefix: /opt/miniconda3/envs/cc-exam

main.py

Lines changed: 85 additions & 140 deletions

@@ -1,13 +1,14 @@
-import argparse
 from pathlib import Path

 import numpy as np
 import torch as th
 import torch.nn as nn
 import wandb
 from torch.utils.data import DataLoader
+from torchvision import transforms
+from tqdm import tqdm

-from utils import MetricWrapper, createfolders, load_data, load_model
+from utils import MetricWrapper, createfolders, get_args, load_data, load_model


 def main():
@@ -23,202 +24,146 @@ def main():
     ------

     """
-    parser = argparse.ArgumentParser(
-        prog="",
-        description="",
-        epilog="",
-    )
-    # Structuture related values
-    parser.add_argument(
-        "--datafolder",
-        type=Path,
-        default="Data",
-        help="Path to where data will be saved during training.",
-    )
-    parser.add_argument(
-        "--resultfolder",
-        type=Path,
-        default="Results",
-        help="Path to where results will be saved during evaluation.",
-    )
-    parser.add_argument(
-        "--modelfolder",
-        type=Path,
-        default="Experiments",
-        help="Path to where model weights will be saved at the end of training.",
-    )
-    parser.add_argument(
-        "--savemodel",
-        type=bool,
-        default=False,
-        help="Whether model should be saved or not.",
-    )
-
-    parser.add_argument(
-        "--download-data",
-        type=bool,
-        default=False,
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
-    # Data/Model specific values
-    parser.add_argument(
-        "--modelname",
-        type=str,
-        default="MagnusModel",
-        choices=["MagnusModel", "ChristianModel"],
-        help="Model which to be trained on",
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="svhn",
-        choices=["svhn", "usps_0-6"],
-        help="Which dataset to train the model on.",
-    )

-    parser.add_argument(
-        "--metric",
-        type=str,
-        default="entropy",
-        choices=["entropy", "f1", "recall", "precision", "accuracy"],
-        nargs="+",
-        help="Which metric to use for evaluation",
-    )
-    parser.add_argument('--imagesize',
-                        type=int,
-                        default=28,
-                        help='Size of images')
-    parser.add_argument('--imagechannels',
-                        type=int,
-                        default=1,
-                        choices=[1,3],
-                        help='Number of color channels in the image.')
-
-
-
-
-    # Training specific values
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=20,
-        help="Amount of training epochs the model will do.",
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=0.001,
-        help="Learning rate parameter for model training.",
-    )
-    parser.add_argument(
-        "--batchsize",
-        type=int,
-        default=64,
-        help="Amount of training images loaded in one go",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        choices=["cuda", "cpu", "mps"],
-        help="Which device to run the training on.",
-    )
-    parser.add_argument(
-        "--dry_run",
-        action="store_true",
-        help="If true, the code will not run the training loop.",
-    )
+    args = get_args()

-    args = parser.parse_args()

     createfolders(args.datafolder, args.resultfolder, args.modelfolder)

     device = args.device

-    metrics = MetricWrapper(*args.metric)
+    if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
+        augmentations = transforms.Compose(
+            [
+                transforms.Resize((16, 16)),
+                transforms.ToTensor(),
+            ]
+        )
+    else:
+        augmentations = transforms.Compose([transforms.ToTensor()])

     # Dataset
     traindata = load_data(
         args.dataset,
         train=True,
         data_path=args.datafolder,
         download=args.download_data,
+        transform=augmentations,
     )
     validata = load_data(
         args.dataset,
         train=False,
         data_path=args.datafolder,
+        download=args.download_data,
+        transform=augmentations,
     )

-    # Find number of channels in the dataset
-    if len(traindata[0][0].shape) == 2:
-        channels = 1
-    else:
-        channels = traindata[0][0].shape[0]
+    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+
+    # Find the shape of the data, if is 2D, add a channel dimension
+    data_shape = traindata[0][0].shape
+    if len(data_shape) == 2:
+        data_shape = (1, *data_shape)

     # load model
     model = load_model(
         args.modelname,
-        in_channels=channels,
+        image_shape=data_shape,
         num_classes=traindata.num_classes,
     )
     model.to(device)

-    trainloader = DataLoader(traindata,
-                             batch_size=args.batchsize,
-                             shuffle=True,
-                             pin_memory=True,
-                             drop_last=True)
-    valiloader = DataLoader(validata,
-                            batch_size=args.batchsize,
-                            shuffle=False,
-                            pin_memory=True)
+    trainloader = DataLoader(
+        traindata,
+        batch_size=args.batchsize,
+        shuffle=True,
+        pin_memory=True,
+        drop_last=True,
+    )
+    valiloader = DataLoader(
+        validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )

     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)

     # This allows us to load all the components without running the training loop
     if args.dry_run:
-        print("Dry run completed")
+        dry_run_loader = DataLoader(
+            traindata,
+            batch_size=20,
+            shuffle=True,
+            pin_memory=True,
+            drop_last=True,
+        )
+
+        for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
+            x, y = x.to(device), y.to(device)
+            logits = model.forward(x)
+
+            loss = criterion(logits, y)
+            loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+
+            preds = th.argmax(logits, dim=1)
+            metrics(y, preds)
+
+            break
+        print(metrics.accumulate())
+        print("Dry run completed successfully.")
         exit(0)

-    wandb.init(project='',
-               tags=[])
+    wandb.login(key=WANDB_API)
+    wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
     wandb.watch(model)

     for epoch in range(args.epoch):
-
         # Training loop start
         trainingloss = []
         model.train()
-        for x, y in trainloader:
+        for x, y in tqdm(trainloader, desc="Training"):
             x, y = x.to(device), y.to(device)
-            pred = model.forward(x)
+            logits = model.forward(x)

-            loss = criterion(y, pred)
+            loss = criterion(logits, y)
             loss.backward()

             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())

+            preds = th.argmax(logits, dim=1)
+            metrics(y, preds)
+
+        wandb.log(metrics.accumulate(str_prefix="Train "))
+        metrics.reset()
+
         evalloss = []
         # Eval loop start
         model.eval()
         with th.no_grad():
-            for x, y in valiloader:
+            for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
-                pred = model.forward(x)
-                loss = criterion(y, pred)
+                logits = model.forward(x)
+                loss = criterion(logits, y)
                 evalloss.append(loss.item())

+                preds = th.argmax(logits, dim=1)
+                metrics(y, preds)
+
+        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
+        metrics.reset()
+
+        wandb.log(
+            {
+                "Epoch": epoch,
+                "Train loss": np.mean(trainingloss),
+                "Evaluation Loss": np.mean(evalloss),
+            }
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
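
Two details of the rewritten main.py deserve a note. First, the long argparse block has moved into `utils.get_args()`. Its implementation is not shown in this diff; judging from the removed lines, it presumably rebuilds the same parser, roughly along these lines (a sketch reconstructed from the deleted code, not the actual utils source):

# Hypothetical sketch of utils.get_args(), reconstructed from the argparse
# block removed from main.py above; the real implementation is not in this diff.
import argparse
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    # Structure-related values
    parser.add_argument("--datafolder", type=Path, default="Data")
    parser.add_argument("--resultfolder", type=Path, default="Results")
    parser.add_argument("--modelfolder", type=Path, default="Experiments")
    # The removed code used `type=bool`, which treats any non-empty string as True;
    # store_true flags are the idiomatic fix and are assumed here.
    parser.add_argument("--savemodel", action="store_true")
    parser.add_argument("--download-data", action="store_true")
    # Data/model-specific values
    parser.add_argument("--modelname", default="MagnusModel",
                        choices=["MagnusModel", "ChristianModel"])
    parser.add_argument("--dataset", default="svhn", choices=["svhn", "usps_0-6"])
    # A list default keeps `MetricWrapper(*args.metric)` splatting metric names
    # rather than the characters of a single string.
    parser.add_argument("--metric", nargs="+", default=["entropy"],
                        choices=["entropy", "f1", "recall", "precision", "accuracy"])
    # Training-specific values
    parser.add_argument("--epoch", type=int, default=20)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batchsize", type=int, default=64)
    parser.add_argument("--device", default="cpu", choices=["cuda", "cpu", "mps"])
    parser.add_argument("--dry_run", action="store_true")
    return parser.parse_args()

Second, `wandb.login(key=WANDB_API)` references a name that is never imported in the file as shown. Together with the new `wandb_api.py` entry in .gitignore, the key is presumably supplied by a git-ignored module, e.g. `from wandb_api import WANDB_API`; if that import is in fact missing from main.py, the script will raise a NameError before login.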

pyproject.toml

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+[tool.isort]
+profile = "black"
+line_length = 88
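
This isort configuration (the `black` profile with black's 88-character line length) matches the formatting visible in the main.py hunk above: imports fall into alphabetized blocks, standard library first, then third-party, then first-party. The new import block of main.py is exactly that shape:

# main.py import order under isort's black profile (as in the diff above):
from pathlib import Path  # standard library

import numpy as np  # third-party, alphabetized
import torch as th
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from utils import MetricWrapper, createfolders, get_args, load_data, load_model  # first-party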
