
Commit 77b5002

Merge pull request #54 from SFI-Visual-Intelligence/Jan-metrics
Added micro/macro averaging option to MetricsWrapper
2 parents 75b1801 + 97750d8 commit 77b5002

24 files changed: 3270 additions, 478 deletions
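The headline change is the averaging option: micro-averaging pools every prediction before scoring, while macro-averaging scores each class separately and then takes the unweighted mean, which keeps rare classes from being drowned out. The toy numbers below are invented purely to illustrate the distinction; they are not taken from this repository's MetricWrapper.

# Illustration only: micro vs. macro averaged recall on an imbalanced toy problem.
import numpy as np

y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])  # class 1 is rare
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])  # model never predicts class 1

# Micro: pool all samples and score once (equals plain accuracy here).
micro = np.mean(y_true == y_pred)  # 0.8, dominated by the majority class

# Macro: score each class on its own, then average the per-class scores.
per_class = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
macro = np.mean(per_class)  # 0.5, exposes the missed minority class

print(micro, macro)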

.python-version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.12

environment.yml

Lines changed: 4 additions & 1 deletion
@@ -9,7 +9,8 @@ dependencies:
   - sphinx-autobuild
   - sphinx-rtd-theme
   - pip
-  - h5py
+  - h5py==3.12.1
+  - hdf5==1.14.4
   - black
   - isort
   - jupyterlab
@@ -20,6 +21,8 @@ dependencies:
   - scalene
   - tqdm
   - scipy
+  - wandb
+  - scikit-learn
   - pip:
       - torch
       - torchvision

main.py

Lines changed: 62 additions & 37 deletions
@@ -7,6 +7,7 @@
 
 import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model
+from wandb_api import WANDB_API
 
 
 def main():
@@ -29,33 +30,38 @@ def main():
 
     device = args.device
 
-    if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+    if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])
 
-    # Dataset
-    traindata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        train=True,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        train=False,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )
 
-    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+    train_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    val_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    test_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
 
     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape
@@ -80,6 +86,9 @@ def main():
     valiloader = DataLoader(
         validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
     )
+    testloader = DataLoader(
+        testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )
 
     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -104,22 +113,22 @@ def main():
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
 
-            metrics(y, logits)
+            train_metrics(y, logits)
 
             break
-        print(metrics.accumulate())
+        print(train_metrics.accumulate())
         print("Dry run completed successfully.")
         exit()
 
     # wandb.login(key=WANDB_API)
     wandb.init(
-        entity="ColabCode-org",
-        # entity="FYS-8805 Exam",
-        project="Test",
-        tags=[args.modelname, args.dataset]
-    )
+        entity="ColabCode",
+        # entity="FYS-8805 Exam",
+        project="Jan",
+        tags=[args.modelname, args.dataset],
+    )
     wandb.watch(model)
-    exit()
+
     for epoch in range(args.epoch):
        # Training loop start
        trainingloss = []
@@ -135,33 +144,49 @@ def main():
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
 
-            metrics(y, logits)
-
-        wandb.log(metrics.accumulate(str_prefix="Train "))
-        metrics.reset()
+            train_metrics(y, logits)
 
-        evalloss = []
-        # Eval loop start
+        valloss = []
+        # Validation loop start
         model.eval()
         with th.no_grad():
             for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
                 logits = model.forward(x)
                 loss = criterion(logits, y)
-                evalloss.append(loss.item())
-
-                metrics(y, logits)
+                valloss.append(loss.item())
 
-        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
-        metrics.reset()
+                val_metrics(y, logits)
 
         wandb.log(
             {
                 "Epoch": epoch,
                 "Train loss": np.mean(trainingloss),
-                "Evaluation Loss": np.mean(evalloss),
+                "Validation loss": np.mean(valloss),
             }
+            | train_metrics.accumulate(str_prefix="Train ")
+            | val_metrics.accumulate(str_prefix="Validation ")
         )
+        train_metrics.reset()
+        val_metrics.reset()
+
+    testloss = []
+    model.eval()
+    with th.no_grad():
+        for x, y in tqdm(testloader, desc="Testing"):
+            x, y = x.to(device), y.to(device)
+            logits = model.forward(x)
+            loss = criterion(logits, y)
+            testloss.append(loss.item())
+
+            preds = th.argmax(logits, dim=1)
+            test_metrics(y, preds)
+
+    wandb.log(
+        {"Epoch": 1, "Test loss": np.mean(testloss)}
+        | test_metrics.accumulate(str_prefix="Test ")
+    )
+    test_metrics.reset()
 
 
 if __name__ == "__main__":
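With the change above, main.py keeps separate train/validation/test wrappers and merges their accumulate() dictionaries into a single wandb.log call with the dict-union operator (|, available since Python 3.9, consistent with the pinned Python 3.12). The class below is only a sketch of the interface that usage implies, reconstructed from the call sites in the diff; the actual MetricWrapper in utils is not shown in this commit and may differ.

# Sketch of the interface main.py relies on (call, accumulate, reset, and the
# macro_averaging flag). This is an assumption based on the diff above, not the
# repository's implementation.
import numpy as np
import torch as th


class MetricWrapperSketch:
    def __init__(self, *metric_names, num_classes, macro_averaging=False):
        self.metric_names = metric_names
        self.num_classes = num_classes
        self.macro_averaging = macro_averaging
        self.reset()

    def __call__(self, y_true, logits):
        # Accept raw logits (2D) or already argmax-ed predictions (1D).
        preds = logits.argmax(dim=-1) if logits.ndim > 1 else logits
        self._true.append(y_true.detach().cpu())
        self._pred.append(preds.detach().cpu())

    def accumulate(self, str_prefix=""):
        y = th.cat(self._true).numpy()
        p = th.cat(self._pred).numpy()
        if self.macro_averaging:
            # Per-class accuracy, averaged with equal weight per class.
            scores = [np.mean(p[y == c] == c) for c in range(self.num_classes) if np.any(y == c)]
            acc = float(np.mean(scores))
        else:
            # Pooled over all samples.
            acc = float(np.mean(y == p))
        return {f"{str_prefix}accuracy": acc}

    def reset(self):
        self._true, self._pred = [], []

Under this reading, the epoch-level wandb.log receives one flat dictionary per epoch, e.g. {"Epoch": 0, "Train loss": ..., "Train accuracy": ..., "Validation accuracy": ...}, which is why the per-split prefixes matter.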

pyproject.toml

Lines changed: 26 additions & 0 deletions
@@ -1,3 +1,29 @@
+[project]
+name = "collaborative-coding-exam"
+version = "0.1.0"
+description = "Exam project in the collaborative coding course."
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "black>=25.1.0",
+    "h5py>=3.12.1",
+    "isort>=6.0.0",
+    "jupyterlab>=4.3.5",
+    "numpy>=2.2.2",
+    "pandas>=2.2.3",
+    "pip>=25.0",
+    "pytest>=8.3.4",
+    "ruff>=0.9.4",
+    "scalene>=1.5.51",
+    "scikit-learn>=1.6.1",
+    "sphinx>=8.1.3",
+    "sphinx-autoapi>=3.4.0",
+    "sphinx-autobuild>=2024.10.3",
+    "sphinx-rtd-theme>=3.0.2",
+    "torch>=2.6.0",
+    "torchvision>=0.21.0",
+    "tqdm>=4.67.1",
+]
 [tool.isort]
 profile = "black"
 line_length = 88

tests/test_dataloaders.py

Lines changed: 11 additions & 4 deletions
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
 
     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets
 
     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
             transforms.ToTensor(),
         ]
     )
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
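The test now constructs USPSDataset0_6 with an explicit sample_ids array, which suggests the dataset exposes only the requested rows of the HDF5 file and that load_data uses index lists to carve out the train/validation split controlled by val_size. The snippet below is a hypothetical illustration of that pattern using scikit-learn (newly added to the dependencies); load_data's real logic is not part of this diff.

# Hypothetical sketch: splitting one set of row indices into train/validation
# ids that could be passed as sample_ids. Not taken from load_data itself.
import numpy as np
from sklearn.model_selection import train_test_split

all_ids = np.arange(10)  # e.g. the 10 rows written in the test above
train_ids, val_ids = train_test_split(all_ids, test_size=0.2, random_state=0)

# Each split would then get its own dataset instance, for example:
# traindata = USPSDataset0_6(data_path=tempdir, sample_ids=train_ids, train=True, transform=trans)
# validata = USPSDataset0_6(data_path=tempdir, sample_ids=val_ids, train=True, transform=trans)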
