Skip to content

Commit 2933536

Browse files
committed
fixed test
1 parent 38d2499 commit 2933536

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

main.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def main():
3131

3232
device = args.device
3333

34-
3534
if args.dataset.lower() == "usps_0-6" or args.dataset.lower() == "uspsh5_7_9":
3635
augmentations = transforms.Compose(
3736
[
@@ -58,8 +57,8 @@ def main():
5857
transform=augmentations,
5958
)
6059

61-
metrics = MetricWrapper(*args.metric, num_classes = traindata.num_classes)
62-
60+
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
61+
6362
# Find the shape of the data, if is 2D, add a channel dimension
6463
data_shape = traindata[0][0].shape
6564
if len(data_shape) == 2:
@@ -106,18 +105,17 @@ def main():
106105

107106
optimizer.step()
108107
optimizer.zero_grad(set_to_none=True)
109-
108+
110109
preds = th.argmax(logits, dim=1)
111110
metrics(y, preds)
112111

113-
114112
break
115113
print(metrics.__getmetrics__())
116114
print("Dry run completed successfully.")
117115
exit(0)
118116

119117
wandb.login(key=WANDB_API)
120-
wandb.init(entity="ColabCode",project="Jan", tags=[args.modelname, args.dataset])
118+
wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
121119
wandb.watch(model)
122120

123121
for epoch in range(args.epoch):
@@ -134,10 +132,10 @@ def main():
134132
optimizer.step()
135133
optimizer.zero_grad(set_to_none=True)
136134
trainingloss.append(loss.item())
137-
135+
138136
preds = th.argmax(logits, dim=1)
139137
metrics(y, preds)
140-
138+
141139
wandb.log(metrics.__getmetrics__(str_prefix="Train "))
142140
metrics.__resetvalues__()
143141

@@ -150,10 +148,10 @@ def main():
150148
logits = model.forward(x)
151149
loss = criterion(logits, y)
152150
evalloss.append(loss.item())
153-
151+
154152
preds = th.argmax(logits, dim=1)
155153
metrics(y, preds)
156-
154+
157155
wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
158156
metrics.__resetvalues__()
159157

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_for_zero_denominator():
8787
def test_accuracy():
8888
import torch
8989

90-
accuracy = Accuracy()
90+
accuracy = Accuracy(num_classes=5)
9191

9292
y_true = torch.tensor([0, 3, 2, 3, 4])
9393
y_pred = torch.tensor([0, 1, 2, 3, 4])

0 commit comments

Comments
 (0)