Skip to content

Commit 24b7726

Browse files
authored
Merge pull request #28 from SFI-Visual-Intelligence/main
Sync
2 parents 5b5da1f + 4350664 commit 24b7726

24 files changed

+844
-237
lines changed

.github/workflows/format.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Format
2+
3+
on:
4+
push:
5+
paths:
6+
- 'utils/**'
7+
pull_request:
8+
paths:
9+
- 'utils/**'
10+
11+
jobs:
12+
format:
13+
name: Run Ruff and isort
14+
runs-on: ubuntu-latest
15+
16+
steps:
17+
- name: Checkout repository
18+
uses: actions/checkout@v4
19+
20+
- name: Set up Python
21+
uses: actions/setup-python@v4
22+
with:
23+
python-version: '3.x'
24+
25+
- name: Install dependencies
26+
run: |
27+
pip install ruff isort
28+
29+
- name: Run Ruff check
30+
run: |
31+
ruff check utils/
32+
33+
- name: Run isort check
34+
run: |
35+
isort --check-only utils/

environment.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ dependencies:
99
- sphinx-autobuild
1010
- sphinx-rtd-theme
1111
- pip
12+
- h5py
13+
- black
14+
- isort
15+
- jupyterlab
16+
- numpy
17+
- pandas
1218
- pytest
19+
- ruff
20+
- scalene
1321
prefix: /opt/miniconda3/envs/cc-exam
1422

main.py

Lines changed: 160 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,158 @@
1-
import torch as th
2-
import torch.nn as nn
3-
from torch.utils.data import DataLoader
41
import argparse
5-
import wandb
2+
from pathlib import Path
3+
64
import numpy as np
7-
from utils import MetricWrapper, load_model, load_data, createfolders
5+
import torch as th
6+
import torch.nn as nn
7+
import wandb
8+
from torch.utils.data import DataLoader
9+
10+
from utils import MetricWrapper, createfolders, load_data, load_model
811

912

1013
def main():
11-
'''
12-
14+
"""
15+
1316
Parameters
1417
----------
15-
18+
1619
Returns
1720
-------
18-
21+
1922
Raises
2023
------
21-
22-
'''
24+
25+
"""
2326
parser = argparse.ArgumentParser(
24-
prog='',
25-
description='',
26-
epilog='',
27-
)
28-
#Structuture related values
29-
parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
30-
parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
31-
parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
32-
parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
33-
34-
parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
35-
36-
#Data/Model specific values
37-
parser.add_argument('--modelname', type=str, default='MagnusModel',
38-
choices = ['MagnusModel'], help="Model which to be trained on")
39-
parser.add_argument('--dataset', type=str, default='svhn',
40-
choices=['svhn'], help='Which dataset to train the model on.')
41-
42-
parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation')
43-
parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation')
44-
parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation')
45-
parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation')
46-
parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation')
47-
48-
#Training specific values
49-
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
50-
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
51-
parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go')
52-
27+
prog="",
28+
description="",
29+
epilog="",
30+
)
31+
# Structuture related values
32+
parser.add_argument(
33+
"--datafolder",
34+
type=Path,
35+
default="Data",
36+
help="Path to where data will be saved during training.",
37+
)
38+
parser.add_argument(
39+
"--resultfolder",
40+
type=Path,
41+
default="Results",
42+
help="Path to where results will be saved during evaluation.",
43+
)
44+
parser.add_argument(
45+
"--modelfolder",
46+
type=Path,
47+
default="Experiments",
48+
help="Path to where model weights will be saved at the end of training.",
49+
)
50+
parser.add_argument(
51+
"--savemodel",
52+
type=bool,
53+
default=False,
54+
help="Whether model should be saved or not.",
55+
)
56+
57+
parser.add_argument(
58+
"--download-data",
59+
type=bool,
60+
default=False,
61+
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
62+
)
63+
64+
# Data/Model specific values
65+
parser.add_argument(
66+
"--modelname",
67+
type=str,
68+
default="MagnusModel",
69+
choices=["MagnusModel", "ChristianModel"],
70+
help="Model which to be trained on",
71+
)
72+
parser.add_argument(
73+
"--dataset",
74+
type=str,
75+
default="svhn",
76+
choices=["svhn", "usps_0-6"],
77+
help="Which dataset to train the model on.",
78+
)
79+
80+
parser.add_argument(
81+
"--metric",
82+
type=str,
83+
default=["entropy"],
84+
choices=["entropy", "f1", "recall", "precision", "accuracy"],
85+
nargs="+",
86+
help="Which metric to use for evaluation",
87+
)
88+
89+
# Training specific values
90+
parser.add_argument(
91+
"--epoch",
92+
type=int,
93+
default=20,
94+
help="Amount of training epochs the model will do.",
95+
)
96+
parser.add_argument(
97+
"--learning_rate",
98+
type=float,
99+
default=0.001,
100+
help="Learning rate parameter for model training.",
101+
)
102+
parser.add_argument(
103+
"--batchsize",
104+
type=int,
105+
default=64,
106+
help="Amount of training images loaded in one go",
107+
)
108+
parser.add_argument(
109+
"--device",
110+
type=str,
111+
default="cpu",
112+
choices=["cuda", "cpu", "mps"],
113+
help="Which device to run the training on.",
114+
)
115+
parser.add_argument(
116+
"--dry_run",
117+
action="store_true",
118+
help="If true, the code will not run the training loop.",
119+
)
120+
53121
args = parser.parse_args()
54-
55-
56-
createfolders(args)
57-
58-
device = 'cuda' if th.cuda.is_available() else 'cpu'
59-
60-
#load model
61-
model = load_model()
122+
123+
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
124+
125+
device = args.device
126+
127+
metrics = MetricWrapper(*args.metric)
128+
129+
# Dataset
130+
traindata = load_data(
131+
args.dataset,
132+
train=True,
133+
data_path=args.datafolder,
134+
download=args.download_data,
135+
)
136+
validata = load_data(
137+
args.dataset,
138+
train=False,
139+
data_path=args.datafolder,
140+
)
141+
142+
# Find number of channels in the dataset
143+
if len(traindata[0][0].shape) == 2:
144+
channels = 1
145+
else:
146+
channels = traindata[0][0].shape[0]
147+
148+
# load model
149+
model = load_model(
150+
args.modelname,
151+
in_channels=channels,
152+
num_classes=traindata.num_classes,
153+
)
62154
model.to(device)
63-
64-
metrics = MetricWrapper(
65-
EntropyPred = args.EntropyPrediction,
66-
F1Score = args.F1Score,
67-
Recall = args.Recall,
68-
Precision = args.Precision,
69-
Accuracy = args.Accuracy
70-
)
71-
72-
#Dataset
73-
traindata = load_data(args.dataset)
74-
validata = load_data(args.dataset)
75-
155+
76156
trainloader = DataLoader(traindata,
77157
batch_size=args.batchsize,
78158
shuffle=True,
@@ -82,48 +162,51 @@ def main():
82162
batch_size=args.batchsize,
83163
shuffle=False,
84164
pin_memory=True)
85-
165+
86166
criterion = nn.CrossEntropyLoss()
87-
optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
88-
89-
167+
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
168+
169+
# This allows us to load all the components without running the training loop
170+
if args.dry_run:
171+
print("Dry run completed")
172+
exit(0)
173+
90174
wandb.init(project='',
91175
tags=[])
92176
wandb.watch(model)
93-
177+
94178
for epoch in range(args.epoch):
95-
96-
#Training loop start
179+
180+
# Training loop start
97181
trainingloss = []
98182
model.train()
99-
for x, y in traindata:
183+
for x, y in trainloader:
100184
x, y = x.to(device), y.to(device)
101185
pred = model.forward(x)
102-
186+
103187
loss = criterion(y, pred)
104188
loss.backward()
105-
189+
106190
optimizer.step()
107191
optimizer.zero_grad(set_to_none=True)
108192
trainingloss.append(loss.item())
109-
193+
110194
evalloss = []
111-
#Eval loop start
195+
# Eval loop start
112196
model.eval()
113197
with th.no_grad():
114198
for x, y in valiloader:
115-
x = x.to(device)
199+
x, y = x.to(device), y.to(device)
116200
pred = model.forward(x)
117201
loss = criterion(y, pred)
118202
evalloss.append(loss.item())
119-
203+
120204
wandb.log({
121205
'Epoch': epoch,
122206
'Train loss': np.mean(trainingloss),
123207
'Evaluation Loss': np.mean(evalloss)
124208
})
125-
126209

127210

128211
if __name__ == '__main__':
129-
main()
212+
main()

0 commit comments

Comments
 (0)