
Commit a6b436c

added num_classes parameter to Metric wrapper and metric calculation step to train, val loops
1 parent c25a2c8 commit a6b436c

File tree

4 files changed: 46 additions & 22 deletions

- main.py
- utils/dataloaders/mnist_0_3.py
- utils/load_metric.py
- utils/metrics/accuracy.py

main.py

Lines changed: 26 additions & 9 deletions

```diff
@@ -31,7 +31,6 @@ def main():
 
     device = args.device
 
-    metrics = MetricWrapper(*args.metric)
 
     if args.dataset.lower() == "usps_0-6" or args.dataset.lower() == "uspsh5_7_9":
         augmentations = transforms.Compose(
@@ -59,6 +58,8 @@ def main():
         transform=augmentations,
     )
 
+    metrics = MetricWrapper(*args.metric, num_classes = traindata.num_classes)
+
     # Find the shape of the data, if is 2D, add a channel dimension
     data_shape = traindata[0][0].shape
     if len(data_shape) == 2:
@@ -90,24 +91,28 @@
     if args.dry_run:
         dry_run_loader = DataLoader(
             traindata,
-            batch_size=1,
+            batch_size=20,
             shuffle=True,
             pin_memory=True,
             drop_last=True,
         )
 
         for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
             x, y = x.to(device), y.to(device)
-            pred = model.forward(x)
+            logits = model.forward(x)
 
-            loss = criterion(y, pred)
+            loss = criterion(logits, y)
             loss.backward()
 
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
+
+            preds = th.argmax(logits, dim=1)
+            metrics(y, preds)
 
-            break
 
+            break
+        print(metrics.__getmetrics__())
         print("Dry run completed successfully.")
         exit(0)
 
@@ -120,24 +125,36 @@ def main():
         model.train()
         for x, y in tqdm(trainloader, desc="Training"):
             x, y = x.to(device), y.to(device)
-            pred = model.forward(x)
+            logits = model.forward(x)
 
-            loss = criterion(y, pred)
+            loss = criterion(logits, y)
             loss.backward()
 
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())
+
+            preds = th.argmax(logits, dim=1)
+            metrics(y, preds)
+
+        wandb.log(metrics.__getmetrics__(str_prefix="Train "))
+        metrics.__resetvalues__()
 
         evalloss = []
         # Eval loop start
         model.eval()
         with th.no_grad():
             for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
-                pred = model.forward(x)
-                loss = criterion(y, pred)
+                logits = model.forward(x)
+                loss = criterion(y, logits)
                 evalloss.append(loss.item())
+
+                preds = th.argmax(logits, dim=1)
+                metrics(y, preds)
+
+        wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
+        metrics.__resetvalues__()
 
         wandb.log(
             {
```
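
The pattern added to both loops is: take hard predictions from the logits with argmax, feed them to the wrapper every batch, then read the averaged scores once per phase and reset. The sketch below replays that flow with a made-up linear model and a random batch; only MetricWrapper, __getmetrics__, and __resetvalues__ come from the repository, and it assumes the underlying metric classes accept 1-D tensors of class indices.

```python
# Rough sketch of the metric flow added above. The model, batch, and metric
# choices are invented for illustration; only MetricWrapper and its
# __getmetrics__ / __resetvalues__ calls come from the repository.
import torch as th

from utils.load_metric import MetricWrapper

num_classes = 4
metrics = MetricWrapper("accuracy", "precision", num_classes=num_classes)

model = th.nn.Linear(28 * 28, num_classes)   # stand-in model
criterion = th.nn.CrossEntropyLoss()

x = th.randn(20, 28 * 28)                    # stand-in batch of 20 samples
y = th.randint(0, num_classes, (20,))

logits = model(x)
loss = criterion(logits, y)                  # criterion takes (logits, targets)
loss.backward()

preds = th.argmax(logits, dim=1)             # hard class predictions
metrics(y, preds)                            # accumulate this batch's scores

# Once per phase: read the averaged scores, then clear the accumulators.
print(metrics.__getmetrics__(str_prefix="Train "))
metrics.__resetvalues__()
```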

utils/dataloaders/mnist_0_3.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -134,11 +134,11 @@ def __len__(self):
 
     def __getitem__(self, index):
         with open(self.labels_path, "rb") as f:
-            f.seek(8 + index) # Jump to the label position
+            f.seek(8 + self.idx[index]) # Jump to the label position
             label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label
 
         with open(self.images_path, "rb") as f:
-            f.seek(16 + index * 28 * 28) # Jump to image position
+            f.seek(16 + self.idx[index] * 28 * 28) # Jump to image position
             image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
                 28, 28
             ) # Read image data
@@ -149,3 +149,5 @@ def __getitem__(self, index):
             image = self.transform(image)
 
         return image, label
+
+
```
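
The two seek fixes matter because the dataset exposes only a filtered subset of MNIST, so a position in the subset is not the sample number in the raw IDX files; self.idx holds that mapping, and before this commit the subset index was used directly, returning the wrong bytes. A self-contained sketch of the same offset arithmetic (the paths and the idx list here are illustrative, not the loader's actual state):

```python
# Standalone illustration of the offset arithmetic behind the fix above.
# `idx` plays the role of the dataset's self.idx: it maps a position in the
# filtered subset back to a sample number in the raw MNIST IDX files.
# The file paths and idx values are placeholders.
import numpy as np


def read_sample(images_path: str, labels_path: str, idx: list[int], index: int):
    raw_index = idx[index]  # subset position -> raw sample number

    with open(labels_path, "rb") as f:
        f.seek(8 + raw_index)  # 8-byte IDX label header, then 1 byte per label
        label = int.from_bytes(f.read(1), byteorder="big")

    with open(images_path, "rb") as f:
        f.seek(16 + raw_index * 28 * 28)  # 16-byte IDX image header, 784 bytes per image
        image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(28, 28)

    return image, label
```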

utils/load_metric.py

Lines changed: 13 additions & 9 deletions

```diff
@@ -3,13 +3,14 @@
 import numpy as np
 import torch.nn as nn
 
-from .metrics import Accuracy, EntropyPrediction, F1Score, Precision
+from .metrics import Accuracy, EntropyPrediction, F1Score, Precision, Recall
 
 
 class MetricWrapper(nn.Module):
-    def __init__(self, *metrics):
+    def __init__(self, *metrics, num_classes):
         super().__init__()
         self.metrics = {}
+        self.num_classes = num_classes
 
         for metric in metrics:
             self.metrics[metric] = self._get_metric(metric)
@@ -33,26 +34,29 @@ def _get_metric(self, key):
 
         match key.lower():
             case "entropy":
-                return EntropyPrediction()
+                return EntropyPrediction(num_classes=self.num_classes)
             case "f1":
-                raise F1Score()
+                return F1Score(num_classes=self.num_classes)
             case "recall":
-                raise NotImplementedError("Recall score not implemented yet")
+                return Recall(num_classes=self.num_classes)
             case "precision":
-                return Precision()
+                return Precision(num_classes=self.num_classes)
             case "accuracy":
-                return Accuracy()
+                return Accuracy(num_classes=self.num_classes)
             case _:
                 raise ValueError(f"Metric {key} not supported")
 
     def __call__(self, y_true, y_pred):
         for key in self.metrics:
             self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
 
-    def __getmetrics__(self):
+    def __getmetrics__(self, str_prefix: str = None):
         return_metrics = {}
         for key in self.metrics:
-            return_metrics[key] = np.mean(self.tmp_scores[key])
+            if str_prefix is not None:
+                return_metrics[str_prefix + key] = np.mean(self.tmp_scores[key])
+            else:
+                return_metrics[key] = np.mean(self.tmp_scores[key])
 
         return return_metrics
 
```
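
The "recall" branch now returns a Recall metric, but its implementation is not part of this commit. Purely as an illustration of why every metric is handed num_classes, a macro-averaged recall could look like the hypothetical class below; MacroRecall is not the repository's class.

```python
# Hypothetical sketch, not the repository's Recall implementation. It shows
# why a metric wants num_classes: macro averaging loops over the full class
# range, not just whatever labels happen to appear in y_pred.
import torch as th
import torch.nn as nn


class MacroRecall(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, y_true: th.Tensor, y_pred: th.Tensor) -> float:
        per_class = []
        for c in range(self.num_classes):
            true_c = y_true == c
            if true_c.sum() == 0:
                continue  # no ground-truth samples of this class in the batch
            tp = (y_pred[true_c] == c).sum().item()
            per_class.append(tp / true_c.sum().item())
        return float(sum(per_class) / max(len(per_class), 1))
```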

utils/metrics/accuracy.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -3,9 +3,10 @@
 
 
 class Accuracy(nn.Module):
-    def __init__(self):
+    def __init__(self, num_classes):
         super().__init__()
-
+        self.num_classes = num_classes
+
     def forward(self, y_true, y_pred):
         """
         Compute the accuracy of the model.
```
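
The forward body is cut off by the hunk above, so the following is only a sketch of what plain multiclass accuracy over class-index tensors computes. Accuracy itself does not strictly need num_classes; the new argument mainly keeps the constructor uniform with the other metrics built by MetricWrapper._get_metric.

```python
# Illustrative only: Accuracy.forward is truncated in the hunk above.
# Plain multiclass accuracy over 1-D tensors of class indices.
import torch as th


def accuracy(y_true: th.Tensor, y_pred: th.Tensor) -> float:
    # Fraction of positions where the predicted class matches the true class.
    return (y_true == y_pred).float().mean().item()
```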
