Skip to content

Commit ff32432

Browse files
committed
added pyproject and mnist dataloader, although not finished yet
2 parents 6757135 + 2ac02eb commit ff32432

28 files changed

+817
-180
lines changed

.gitignore

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
__pycache__/
22
.ipynb_checkpoints/
3-
Data/
4-
Results/
5-
Experiments/
3+
Data/*
4+
Results/*
5+
Experiments/*
66
_build/
7-
bin/
7+
bin/*
8+
wandb/*
9+
wandb_api.py
10+
11+
#Magnus specific
12+
docker/*
13+
job*
814

915
# Byte-compiled / optimized / DLL files
1016
__pycache__/

doc/about.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# About this code
22

3-
Work in progress ...
3+
Work is still in progress ...

doc/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
extensions = [
99
"myst_parser", # in order to use markdown
10+
"autoapi.extension", # in order to generate API documentation
1011
]
1112

13+
# search this directory for Python files
14+
autoapi_dirs = ["../utils"]
15+
1216
myst_enable_extensions = [
1317
"colon_fence", # ::: can be used instead of ``` for better rendering
1418
]

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ dependencies:
1818
- pytest
1919
- ruff
2020
- scalene
21+
- tqdm
22+
- scipy
2123
- pip:
2224
- torch
2325
- torchvision

main.py

Lines changed: 67 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import argparse
2-
from pathlib import Path
3-
41
import numpy as np
52
import torch as th
63
import torch.nn as nn
7-
import wandb
84
from torch.utils.data import DataLoader
5+
from torchvision import transforms
6+
from tqdm import tqdm
97

10-
from utils import MetricWrapper, createfolders, load_data, load_model
8+
import wandb
9+
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1110

1211

1312
def main():
@@ -23,122 +22,41 @@ def main():
2322
------
2423
2524
"""
26-
parser = argparse.ArgumentParser(
27-
prog="",
28-
description="",
29-
epilog="",
30-
)
31-
# Structuture related values
32-
parser.add_argument(
33-
"--datafolder",
34-
type=Path,
35-
default="Data",
36-
help="Path to where data will be saved during training.",
37-
)
38-
parser.add_argument(
39-
"--resultfolder",
40-
type=Path,
41-
default="Results",
42-
help="Path to where results will be saved during evaluation.",
43-
)
44-
parser.add_argument(
45-
"--modelfolder",
46-
type=Path,
47-
default="Experiments",
48-
help="Path to where model weights will be saved at the end of training.",
49-
)
50-
parser.add_argument(
51-
"--savemodel",
52-
type=bool,
53-
default=False,
54-
help="Whether model should be saved or not.",
55-
)
56-
57-
parser.add_argument(
58-
"--download-data",
59-
type=bool,
60-
default=False,
61-
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
62-
)
6325

64-
# Data/Model specific values
65-
parser.add_argument(
66-
"--modelname",
67-
type=str,
68-
default="MagnusModel",
69-
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
70-
help="Model which to be trained on",
71-
)
72-
parser.add_argument(
73-
"--dataset",
74-
type=str,
75-
default="svhn",
76-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
77-
help="Which dataset to train the model on.",
78-
)
79-
80-
parser.add_argument(
81-
"--metric",
82-
type=str,
83-
default=["entropy"],
84-
choices=["entropy", "f1", "recall", "precision", "accuracy"],
85-
nargs="+",
86-
help="Which metric to use for evaluation",
87-
)
88-
89-
# Training specific values
90-
parser.add_argument(
91-
"--epoch",
92-
type=int,
93-
default=20,
94-
help="Amount of training epochs the model will do.",
95-
)
96-
parser.add_argument(
97-
"--learning_rate",
98-
type=float,
99-
default=0.001,
100-
help="Learning rate parameter for model training.",
101-
)
102-
parser.add_argument(
103-
"--batchsize",
104-
type=int,
105-
default=64,
106-
help="Amount of training images loaded in one go",
107-
)
108-
parser.add_argument(
109-
"--device",
110-
type=str,
111-
default="cpu",
112-
choices=["cuda", "cpu", "mps"],
113-
help="Which device to run the training on.",
114-
)
115-
parser.add_argument(
116-
"--dry_run",
117-
action="store_true",
118-
help="If true, the code will not run the training loop.",
119-
)
120-
121-
args = parser.parse_args()
26+
args = get_args()
12227

12328
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
12429

12530
device = args.device
12631

127-
metrics = MetricWrapper(*args.metric)
32+
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
33+
augmentations = transforms.Compose(
34+
[
35+
transforms.Resize((16, 16)),
36+
transforms.ToTensor(),
37+
]
38+
)
39+
else:
40+
augmentations = transforms.Compose([transforms.ToTensor()])
12841

12942
# Dataset
13043
traindata = load_data(
13144
args.dataset,
13245
train=True,
13346
data_path=args.datafolder,
13447
download=args.download_data,
48+
transform=augmentations,
13549
)
13650
validata = load_data(
13751
args.dataset,
13852
train=False,
13953
data_path=args.datafolder,
54+
download=args.download_data,
55+
transform=augmentations,
14056
)
14157

58+
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
59+
14260
# Find the shape of the data, if is 2D, add a channel dimension
14361
data_shape = traindata[0][0].shape
14462
if len(data_shape) == 2:
@@ -168,37 +86,75 @@ def main():
16886

16987
# This allows us to load all the components without running the training loop
17088
if args.dry_run:
171-
print("Dry run completed")
172-
exit(0)
89+
dry_run_loader = DataLoader(
90+
traindata,
91+
batch_size=20,
92+
shuffle=True,
93+
pin_memory=True,
94+
drop_last=True,
95+
)
17396

174-
wandb.init(project="", tags=[])
175-
wandb.watch(model)
97+
for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
98+
x, y = x.to(device), y.to(device)
99+
logits = model.forward(x)
176100

101+
loss = criterion(logits, y)
102+
loss.backward()
103+
104+
optimizer.step()
105+
optimizer.zero_grad(set_to_none=True)
106+
107+
metrics(y, logits)
108+
109+
break
110+
print(metrics.accumulate())
111+
print("Dry run completed successfully.")
112+
exit()
113+
114+
# wandb.login(key=WANDB_API)
115+
wandb.init(
116+
entity="ColabCode-org",
117+
# entity="FYS-8805 Exam",
118+
project="Test",
119+
tags=[args.modelname, args.dataset]
120+
)
121+
wandb.watch(model)
122+
exit()
177123
for epoch in range(args.epoch):
178124
# Training loop start
179125
trainingloss = []
180126
model.train()
181-
for x, y in trainloader:
127+
for x, y in tqdm(trainloader, desc="Training"):
182128
x, y = x.to(device), y.to(device)
183-
pred = model.forward(x)
129+
logits = model.forward(x)
184130

185-
loss = criterion(y, pred)
131+
loss = criterion(logits, y)
186132
loss.backward()
187133

188134
optimizer.step()
189135
optimizer.zero_grad(set_to_none=True)
190136
trainingloss.append(loss.item())
191137

138+
metrics(y, logits)
139+
140+
wandb.log(metrics.accumulate(str_prefix="Train "))
141+
metrics.reset()
142+
192143
evalloss = []
193144
# Eval loop start
194145
model.eval()
195146
with th.no_grad():
196-
for x, y in valiloader:
147+
for x, y in tqdm(valiloader, desc="Validation"):
197148
x, y = x.to(device), y.to(device)
198-
pred = model.forward(x)
199-
loss = criterion(y, pred)
149+
logits = model.forward(x)
150+
loss = criterion(logits, y)
200151
evalloss.append(loss.item())
201152

153+
metrics(y, logits)
154+
155+
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
156+
metrics.reset()
157+
202158
wandb.log(
203159
{
204160
"Epoch": epoch,

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ dependencies = [
2222
"torch>=2.6.0",
2323
"torchvision>=0.21.0",
2424
]
25+
[tool.isort]
26+
profile = "black"
27+
line_length = 88

tests/test_dataloaders.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,31 @@
33

44
def test_uspsdataset0_6():
55
from pathlib import Path
6-
from tempfile import TemporaryFile
6+
from tempfile import TemporaryDirectory
77

88
import h5py
99
import numpy as np
10+
from torchvision import transforms
1011

11-
with TemporaryFile() as tf:
12+
# Create a temporary directory (deleted after the test)
13+
with TemporaryDirectory() as tempdir:
14+
tempdir = Path(tempdir)
15+
16+
tf = tempdir / "usps.h5"
17+
18+
# Create a h5 file
1219
with h5py.File(tf, "w") as f:
20+
# Populate the file with data
1321
f["train/data"] = np.random.rand(10, 16 * 16)
1422
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
1523

16-
dataset = USPSDataset0_6(data_path=tf, train=True)
24+
trans = transforms.Compose(
25+
[
26+
transforms.Resize((16, 16)), # At least for USPS
27+
transforms.ToTensor(),
28+
]
29+
)
30+
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
1731
assert len(dataset) == 10
1832
data, target = dataset[0]
1933
assert data.shape == (1, 16, 16)

tests/test_metrics.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
2-
from utils.metrics import F1Score, Precision, Recall
3-
1+
from utils.metrics import Accuracy, F1Score, Precision, Recall
42

53

64
def test_recall():
@@ -84,3 +82,18 @@ def test_for_zero_denominator():
8482
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
8583
f"Precision Score: {precision4.item()}"
8684
)
85+
86+
87+
def test_accuracy():
88+
import torch
89+
90+
accuracy = Accuracy(num_classes=5)
91+
92+
y_true = torch.tensor([0, 3, 2, 3, 4])
93+
y_pred = torch.tensor([0, 1, 2, 3, 4])
94+
95+
accuracy_score = accuracy(y_true, y_pred)
96+
97+
assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
98+
f"Accuracy Score: {accuracy_score.item()}"
99+
)

tests/test_models.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel
4+
from utils.models import ChristianModel, JanModel
55

66

77
@pytest.mark.parametrize(
@@ -17,6 +17,19 @@ def test_christian_model(image_shape, num_classes):
1717
y = model(x)
1818

1919
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
20-
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
21-
f"Softmax output should sum to 1, but got: {y.sum()}"
22-
)
20+
21+
22+
@pytest.mark.parametrize(
23+
"image_shape, num_classes",
24+
[((1, 28, 28), 4), ((3, 16, 16), 10)],
25+
)
26+
def test_jan_model(image_shape, num_classes):
27+
n, c, h, w = 5, *image_shape
28+
29+
model = JanModel(image_shape, num_classes)
30+
31+
x = torch.randn(n, c, h, w)
32+
y = model(x)
33+
34+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
35+

utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"]
1+
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper", "get_args"]
22

3+
from .arg_parser import get_args
34
from .createfolders import createfolders
45
from .load_data import load_data
56
from .load_metric import MetricWrapper

0 commit comments

Comments
 (0)