Skip to content

Commit 19a6ea1

Browse files
committed
hopefully fixed f1 test
1 parent 0d5fc20 commit 19a6ea1

File tree

3 files changed

+37
-33
lines changed

3 files changed

+37
-33
lines changed

main.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1010
from wandb_api import WANDB_API
1111

12+
1213
def main():
1314
"""
1415
@@ -46,7 +47,21 @@ def main():
4647
val_size=args.val_size,
4748
)
4849

49-
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes, macro_averaging=args.macro_averaging)
50+
train_metrics = MetricWrapper(
51+
*args.metric,
52+
num_classes=traindata.num_classes,
53+
macro_averaging=args.macro_averaging,
54+
)
55+
val_metrics = MetricWrapper(
56+
*args.metric,
57+
num_classes=traindata.num_classes,
58+
macro_averaging=args.macro_averaging,
59+
)
60+
test_metrics = MetricWrapper(
61+
*args.metric,
62+
num_classes=traindata.num_classes,
63+
macro_averaging=args.macro_averaging,
64+
)
5065

5166
# Find the shape of the data, if is 2D, add a channel dimension
5267
data_shape = traindata[0][0].shape
@@ -98,22 +113,22 @@ def main():
98113
optimizer.step()
99114
optimizer.zero_grad(set_to_none=True)
100115

101-
metrics(y, logits)
116+
train_metrics(y, logits)
102117

103118
break
104-
print(metrics.accumulate())
119+
print(train_metrics.accumulate())
105120
print("Dry run completed successfully.")
106121
exit()
107122

108123
# wandb.login(key=WANDB_API)
109124
wandb.init(
110-
entity="ColabCode",
111-
# entity="FYS-8805 Exam",
112-
project="Jan",
113-
tags=[args.modelname, args.dataset]
114-
)
125+
entity="ColabCode",
126+
# entity="FYS-8805 Exam",
127+
project="Jan",
128+
tags=[args.modelname, args.dataset],
129+
)
115130
wandb.watch(model)
116-
131+
117132
for epoch in range(args.epoch):
118133
# Training loop start
119134
trainingloss = []
@@ -129,10 +144,7 @@ def main():
129144
optimizer.zero_grad(set_to_none=True)
130145
trainingloss.append(loss.item())
131146

132-
metrics(y, logits)
133-
134-
wandb.log(metrics.accumulate(str_prefix="Train "))
135-
metrics.reset()
147+
train_metrics(y, logits)
136148

137149
valloss = []
138150
# Validation loop start
@@ -144,18 +156,19 @@ def main():
144156
loss = criterion(logits, y)
145157
valloss.append(loss.item())
146158

147-
metrics(y, logits)
148-
149-
wandb.log(metrics.accumulate(str_prefix="Validation "))
150-
metrics.reset()
159+
val_metrics(y, logits)
151160

152161
wandb.log(
153162
{
154163
"Epoch": epoch,
155164
"Train loss": np.mean(trainingloss),
156165
"Validation loss": np.mean(valloss),
157166
}
167+
| train_metrics.accumulate(str_prefix="Train ")
168+
| val_metrics.accumulate(str_prefix="Validation ")
158169
)
170+
train_metrics.reset()
171+
val_metrics.reset()
159172

160173
testloss = []
161174
model.eval()
@@ -167,11 +180,13 @@ def main():
167180
testloss.append(loss.item())
168181

169182
preds = th.argmax(logits, dim=1)
170-
metrics(y, preds)
183+
test_metrics(y, preds)
171184

172-
wandb.log(metrics.accumulate(str_prefix="Test "))
173-
metrics.reset()
174-
wandb.log({"Test loss": np.mean(testloss)})
185+
wandb.log(
186+
{"Epoch": 1, "Test loss": np.mean(testloss)}
187+
| test_metrics.accumulate(str_prefix="Test ")
188+
)
189+
test_metrics.reset()
175190

176191

177192
if __name__ == "__main__":

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_f1score():
2626

2727
target = torch.tensor([0, 1, 0, 2])
2828

29-
f1_metric.update(preds, target)
29+
f1_metric(preds, target)
3030
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
3131
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
3232
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

utils/arg_parser.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,6 @@ def get_args():
7373
action="store_true",
7474
help="If the flag is included, the metrics will be calculated using macro averaging.",
7575
)
76-
77-
78-
parser.add_argument("--imagesize", type=int, default=28, help="Imagesize")
79-
80-
parser.add_argument(
81-
"--nr_channels",
82-
type=int,
83-
default=1,
84-
choices=[1, 3],
85-
help="Number of image channels",
86-
)
8776

8877
# Training specific values
8978
parser.add_argument(

0 commit comments

Comments
 (0)