Skip to content

Commit 8489dc1

Browse files
authored
Merge pull request #94 from SFI-Visual-Intelligence/johan/devbranch
Updatet main to work with metricWrapper.
2 parents 5f67599 + cc9dce0 commit 8489dc1

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ local*
1919

2020
# Johanthings
2121
formatting.x
22+
testrun.x
23+
storage/
2224

2325
# Byte-compiled / optimized / DLL files
2426
__pycache__/

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ def __init__(
4343
self.labels_path = self.mnist_path / (
4444
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
4545
)
46-
46+
47+
# Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
48+
self.label_shift = lambda x: x-4
49+
self.label_restore = lambda x: x+4
50+
51+
4752
def __len__(self):
4853
return len(self.samples)
4954

@@ -66,4 +71,4 @@ def __getitem__(self, idx):
6671
if self.transform:
6772
image = self.transform(image)
6873

69-
return image, label
74+
return image, self.label_shift(label)

main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ def main():
139139

140140
for epoch in range(args.epoch):
141141
# Training loop start
142+
print(f"Epoch: {epoch+1}/{args.epoch}")
142143
trainingloss = []
143144
model.train()
144145
for x, y in tqdm(trainloader, desc="Training"):
145146
x, y = x.to(device), y.to(device)
146147
logits = model.forward(x)
147-
148148
loss = criterion(logits, y)
149149
loss.backward()
150150

@@ -172,8 +172,8 @@ def main():
172172
"Train loss": np.mean(trainingloss),
173173
"Validation loss": np.mean(valloss),
174174
}
175-
| train_metrics.getmetric(str_prefix="Train ")
176-
| val_metrics.getmetric(str_prefix="Validation ")
175+
| train_metrics.getmetrics(str_prefix="Train ")
176+
| val_metrics.getmetrics(str_prefix="Validation ")
177177
)
178178
train_metrics.resetmetric()
179179
val_metrics.resetmetric()
@@ -187,12 +187,11 @@ def main():
187187
loss = criterion(logits, y)
188188
testloss.append(loss.item())
189189

190-
preds = th.argmax(logits, dim=1)
191-
test_metrics(y, preds)
190+
test_metrics(y, logits)
192191

193192
wandb.log(
194193
{"Epoch": 1, "Test loss": np.mean(testloss)}
195-
| test_metrics.getmetric(str_prefix="Test ")
194+
| test_metrics.getmetrics(str_prefix="Test ")
196195
)
197196
test_metrics.resetmetric()
198197

0 commit comments

Comments
 (0)