
Commit a412297

Merge commit 'refs/pull/54/head' of github.com:SFI-Visual-Intelligence/Collaborative-Coding-Exam into johan/devbranch

2 parents ff32432 + a7d51c4

File tree: 17 files changed (+599, −383 lines)


environment.yml

Lines changed: 3 additions & 1 deletion
@@ -9,7 +9,8 @@ dependencies:
   - sphinx-autobuild
   - sphinx-rtd-theme
   - pip
-  - h5py
+  - h5py==3.12.1
+  - hdf5==1.14.4
   - black
   - isort
   - jupyterlab
@@ -20,6 +21,7 @@ dependencies:
   - scalene
   - tqdm
   - scipy
+  - wandb
   - pip:
       - torch
       - torchvision
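
A quick way to confirm the new pins resolved as intended (h5py 3.12.1 built against HDF5 1.14.4), assuming the environment above is active:

import h5py

# Both values should match the pins in environment.yml.
print(h5py.__version__)           # expected: 3.12.1
print(h5py.version.hdf5_version)  # expected: 1.14.4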

main.py

Lines changed: 62 additions & 37 deletions
@@ -7,6 +7,7 @@

 import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model
+from wandb_api import WANDB_API


 def main():
@@ -29,33 +30,38 @@ def main():

     device = args.device

-    if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+    if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])

-    # Dataset
-    traindata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        train=True,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        train=False,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )

-    metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
+    train_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    val_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )
+    test_metrics = MetricWrapper(
+        *args.metric,
+        num_classes=traindata.num_classes,
+        macro_averaging=args.macro_averaging,
+    )

     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape
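
The single load_data call now returns train/validation/test splits, with the validation split sized by --val_size. A minimal sketch of what such a helper can look like, assuming torch datasets and a random split (the repository's actual implementation may differ):

import torch as th
from torch.utils.data import random_split

def load_data_sketch(full_train, testdata, val_size=0.2, seed=0):
    # Carve val_size of the training set off as a validation split.
    n_val = int(len(full_train) * val_size)
    generator = th.Generator().manual_seed(seed)
    traindata, validata = random_split(
        full_train, [len(full_train) - n_val, n_val], generator=generator
    )
    return traindata, validata, testdata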
@@ -80,6 +86,9 @@ def main():
     valiloader = DataLoader(
         validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
     )
+    testloader = DataLoader(
+        testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )

     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -104,22 +113,22 @@ def main():
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)

-            metrics(y, logits)
+            train_metrics(y, logits)

             break
-        print(metrics.accumulate())
+        print(train_metrics.accumulate())
         print("Dry run completed successfully.")
         exit()

     # wandb.login(key=WANDB_API)
     wandb.init(
-        entity="ColabCode-org",
-        # entity="FYS-8805 Exam",
-        project="Test",
-        tags=[args.modelname, args.dataset]
-    )
+        entity="ColabCode",
+        # entity="FYS-8805 Exam",
+        project="Jan",
+        tags=[args.modelname, args.dataset],
+    )
     wandb.watch(model)
-    exit()
+
     for epoch in range(args.epoch):
         # Training loop start
         trainingloss = []
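
The init/watch pattern above in isolation; entity and project mirror the diff, while the model and tags are stand-ins (wandb.watch hooks the model so gradients and parameters are logged as training runs):

import wandb
from torch import nn

run = wandb.init(entity="ColabCode", project="Jan", tags=["demo"])
model = nn.Linear(16 * 16, 7)
wandb.watch(model, log="gradients", log_freq=100)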
@@ -135,33 +144,49 @@ def main():
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())

-            metrics(y, logits)
-
-        wandb.log(metrics.accumulate(str_prefix="Train "))
-        metrics.reset()
+            train_metrics(y, logits)

-        evalloss = []
-        # Eval loop start
+        valloss = []
+        # Validation loop start
         model.eval()
         with th.no_grad():
             for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
                 logits = model.forward(x)
                 loss = criterion(logits, y)
-                evalloss.append(loss.item())
-
-                metrics(y, logits)
+                valloss.append(loss.item())

-        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
-        metrics.reset()
+                val_metrics(y, logits)

         wandb.log(
             {
                 "Epoch": epoch,
                 "Train loss": np.mean(trainingloss),
-                "Evaluation Loss": np.mean(evalloss),
+                "Validation loss": np.mean(valloss),
             }
+            | train_metrics.accumulate(str_prefix="Train ")
+            | val_metrics.accumulate(str_prefix="Validation ")
         )
+        train_metrics.reset()
+        val_metrics.reset()
+
+    testloss = []
+    model.eval()
+    with th.no_grad():
+        for x, y in tqdm(testloader, desc="Testing"):
+            x, y = x.to(device), y.to(device)
+            logits = model.forward(x)
+            loss = criterion(logits, y)
+            testloss.append(loss.item())
+
+            preds = th.argmax(logits, dim=1)
+            test_metrics(y, preds)
+
+    wandb.log(
+        {"Epoch": 1, "Test loss": np.mean(testloss)}
+        | test_metrics.accumulate(str_prefix="Test ")
+    )
+    test_metrics.reset()


 if __name__ == "__main__":
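
The per-epoch wandb.log call now merges the loss dict with each metric dict via the dict-union operator (PEP 584, Python 3.9+), so everything lands in one logging step. With illustrative values:

losses = {"Epoch": 0, "Train loss": 0.42}
train_scores = {"Train accuracy": 0.91}     # stand-in for train_metrics.accumulate()
val_scores = {"Validation accuracy": 0.88}  # stand-in for val_metrics.accumulate()
payload = losses | train_scores | val_scores
# wandb.log(payload) would log one flat dict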

tests/test_dataloaders.py

Lines changed: 11 additions & 4 deletions
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets

     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
             transforms.ToTensor(),
         ]
     )
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
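
The expected target changes from a one-hot vector to an integer class index, which is the format nn.CrossEntropyLoss in main.py consumes directly. For illustration (made-up values):

import torch

logits = torch.randn(1, 7)  # one sample, 7 classes (digits 0-6)
target = torch.tensor([6])  # integer label, replacing the old one-hot form
loss = torch.nn.CrossEntropyLoss()(logits, target)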

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def test_f1score():

     target = torch.tensor([0, 1, 0, 2])

-    f1_metric.update(preds, target)
+    f1_metric(preds, target)
     assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
     assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
     assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
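
Calling f1_metric(preds, target) works when the metric is an nn.Module whose forward() accumulates the counts, making __call__ equivalent to the old explicit update(). A hypothetical minimal shape, not the repository's actual class:

import torch
from torch import nn

class F1Sketch(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.tp = torch.zeros(num_classes)
        self.fp = torch.zeros(num_classes)
        self.fn = torch.zeros(num_classes)

    def forward(self, preds, target):
        # Accumulate per-class true positives, false positives, false negatives.
        for c in range(len(self.tp)):
            self.tp[c] += ((preds == c) & (target == c)).sum()
            self.fp[c] += ((preds == c) & (target != c)).sum()
            self.fn[c] += ((preds != c) & (target == c)).sum()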

tests/test_models.py

Lines changed: 0 additions & 1 deletion
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)

     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-

utils/arg_parser.py

Lines changed: 12 additions & 8 deletions
@@ -33,12 +33,6 @@ def get_args():
         help="Whether model should be saved or not.",
     )

-    parser.add_argument(
-        "--download-data",
-        action="store_true",
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
@@ -60,7 +54,12 @@ def get_args():
         choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
         help="Which dataset to train the model on.",
    )
-
+    parser.add_argument(
+        "--val_size",
+        type=float,
+        default=0.2,
+        help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
+    )
     parser.add_argument(
         "--metric",
         type=str,
@@ -69,6 +68,11 @@ def get_args():
         nargs="+",
         help="Which metric to use for evaluation",
     )
+    parser.add_argument(
+        "--macro_averaging",
+        action="store_true",
+        help="If the flag is included, the metrics will be calculated using macro averaging.",
+    )

     # Training specific values
     parser.add_argument(
@@ -99,7 +103,7 @@ def get_args():
     parser.add_argument(
         "--dry_run",
         action="store_true",
-        help="If true, the code will not run the training loop.",
+        help="If the flag is included, the code will not run the training loop.",
     )
     args = parser.parse_args()
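
The help text says --val_size must lie in (0, 1), but nothing enforces it at parse time. A sketch of a stricter variant using a custom argparse type (a suggestion, not what the commit does):

import argparse

def fraction(value):
    # Reject values outside the open interval (0, 1) at parse time.
    f = float(value)
    if not 0.0 < f < 1.0:
        raise argparse.ArgumentTypeError(f"val_size must be within (0, 1), got {f}")
    return f

parser = argparse.ArgumentParser()
parser.add_argument("--val_size", type=fraction, default=0.2)
print(parser.parse_args(["--val_size", "0.3"]).val_size)  # 0.3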

utils/dataloaders/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -1,6 +1,13 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+    "SVHNDataset",
+]

+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
+from .svhn import SVHNDataset
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
-from .svhn import SVHNDataset
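
With Downloader exported alongside the datasets, consumers can import everything from the package root, e.g.:

from utils.dataloaders import Downloader, MNISTDataset0_3, SVHNDataset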

utils/dataloaders/datasources.py

Lines changed: 23 additions & 0 deletions
@@ -17,3 +17,26 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
+    ],
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
+    ],
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
+    ],
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
+    ],
+}
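
Each entry pairs a download URL with the extracted filename and an optional checksum (None for the MNIST files; the SVHN entry above carries what looks like an md5 digest). A hypothetical consumer of these triples, assuming gzip archives (the repository's Downloader may differ):

import gzip
import hashlib
import urllib.request
from pathlib import Path

def fetch(url, name, md5, out_dir):
    # Download, decompress, optionally verify, and cache one source file.
    dest = Path(out_dir) / name
    if not dest.exists():
        raw, _ = urllib.request.urlretrieve(url)
        data = gzip.decompress(Path(raw).read_bytes())
        if md5 is not None and hashlib.md5(data).hexdigest() != md5:
            raise ValueError(f"Checksum mismatch for {name}")
        dest.write_bytes(data)
    return dest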
