Skip to content

Commit 2341c69

Browse files
committed
Formatting
1 parent 2933536 commit 2341c69

File tree

4 files changed

+4
-7
lines changed

4 files changed

+4
-7
lines changed

main.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import argparse
21
from pathlib import Path
32

43
import numpy as np
@@ -9,7 +8,7 @@
98
from torchvision import transforms
109
from tqdm import tqdm
1110

12-
from utils import MetricWrapper, createfolders, load_data, load_model, get_args
11+
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1312

1413

1514
def main():
@@ -31,7 +30,7 @@ def main():
3130

3231
device = args.device
3332

34-
if args.dataset.lower() == "usps_0-6" or args.dataset.lower() == "uspsh5_7_9":
33+
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
3534
augmentations = transforms.Compose(
3635
[
3736
transforms.Resize((16, 16)),

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,6 @@ def test_accuracy():
9494

9595
accuracy_score = accuracy(y_true, y_pred)
9696

97-
assert (torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5), (
97+
assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
9898
f"Accuracy Score: {accuracy_score.item()}"
9999
)

utils/dataloaders/mnist_0_3.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,3 @@ def __getitem__(self, index):
149149
image = self.transform(image)
150150

151151
return image, label
152-
153-

utils/metrics/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class Accuracy(nn.Module):
66
def __init__(self, num_classes):
77
super().__init__()
88
self.num_classes = num_classes
9-
9+
1010
def forward(self, y_true, y_pred):
1111
"""
1212
Compute the accuracy of the model.

0 commit comments

Comments
 (0)