
Commit 1285d36

ran ruff and isort
1 parent ecb6db4 commit 1285d36

File tree: 6 files changed, +69 / -52 lines
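
The commit message says this is a pure formatting pass. A plausible way to reproduce this kind of change locally, assuming ruff and isort are installed and run with default settings (the project's actual configuration is not shown in this commit), is:

    isort utils/
    ruff check --fix utils/
    ruff format utils/

The six diffs below are consistent with that: imports are grouped and alphabetized, overlong calls are wrapped, single quotes become double quotes, and stray blank lines are normalized.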

utils/dataloaders/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-__all__ = ["USPSDataset0_6","MNISTDataset0_3"]
+__all__ = ["USPSDataset0_6", "MNISTDataset0_3"]
 
+from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
-from .mnist_0_3 import MNISTDataset0_3

utils/dataloaders/mnist_0_3.py

Lines changed: 54 additions & 38 deletions
@@ -1,11 +1,10 @@
-from pathlib import Path
-
-from torch.utils.data import Dataset
-import numpy as np
-import urllib.request
 import gzip
 import os
+import urllib.request
+from pathlib import Path
 
+import numpy as np
+from torch.utils.data import Dataset
 
 
 class MNISTDataset0_3(Dataset):
@@ -54,39 +53,56 @@ class MNISTDataset0_3(Dataset):
     __getitem__(index)
         Returns the image and label at the specified index.
     """
-    def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False,):
+
+    def __init__(
+        self,
+        data_path: Path,
+        train: bool = False,
+        transform=None,
+        download: bool = False,
+    ):
         super().__init__()
-
+
         self.data_path = data_path
         self.mnist_path = self.data_path / "MNIST"
         self.train = train
         self.transform = transform
         self.download = download
         self.num_classes = 4
-
+
         if not self.download and not self._chech_is_downloaded():
-            raise ValueError("Data not found. Set --download-data=True to download the data.")
+            raise ValueError(
+                "Data not found. Set --download-data=True to download the data."
+            )
         if self.download and not self._chech_is_downloaded():
             self._download_data()
-
-        self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte")
-        self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte")
-
+
+        self.images_path = self.mnist_path / (
+            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
+        )
+        self.labels_path = self.mnist_path / (
+            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
+        )
+
         labels = self._parse_labels(train=self.train)
-
-        self.idx = np.where(labels < 4)[0]
-
+
+        self.idx = np.where(labels < 4)[0]
+
         self.length = len(self.idx)
-
-
+
     def _parse_labels(self, train):
         with open(self.labels_path, "rb") as f:
             data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
         return data
-
+
     def _chech_is_downloaded(self):
         if self.mnist_path.exists():
-            required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"]
+            required_files = [
+                "train-images-idx3-ubyte",
+                "train-labels-idx1-ubyte",
+                "t10k-images-idx3-ubyte",
+                "t10k-labels-idx1-ubyte",
+            ]
             if all([(self.mnist_path / file).exists() for file in required_files]):
                 print("MNIST Dataset already downloaded.")
                 return True
@@ -95,26 +111,24 @@ def _chech_is_downloaded(self):
         else:
             self.mnist_path.mkdir(parents=True, exist_ok=True)
             return False
-
-
+
     def _download_data(self):
         urls = {
-            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
-            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
-            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
-            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
-        }
-
+            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        }
+
         for name, url in urls.items():
             file_path = os.path.join(self.mnist_path, url.split("/")[-1])
             if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
                 urllib.request.urlretrieve(url, file_path)
-                with gzip.open(file_path, 'rb') as f_in:
-                    with open(file_path.replace(".gz", ""), 'wb') as f_out:
+                with gzip.open(file_path, "rb") as f_in:
+                    with open(file_path.replace(".gz", ""), "wb") as f_out:
                         f_out.write(f_in.read())
                 os.remove(file_path)  # Remove compressed file
 
-
     def __len__(self):
         return self.length
 
@@ -124,12 +138,14 @@ def __getitem__(self, index):
             label = int.from_bytes(f.read(1), byteorder="big")  # Read 1 byte for label
 
         with open(self.images_path, "rb") as f:
-            f.seek(16 + index * 28*28)  # Jump to image position
-            image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28)  # Read image data
-
-        image = np.expand_dims(image, axis=0)  # Add channel dimension
-
+            f.seek(16 + index * 28 * 28)  # Jump to image position
+            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
+                28, 28
+            )  # Read image data
+
+        image = np.expand_dims(image, axis=0)  # Add channel dimension
+
         if self.transform:
             image = self.transform(image)
-
-        return image, label
+
+        return image, label
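
The reformatted constructor keeps the original call signature, so a minimal usage sketch looks like the following (the data/ directory, batch size, and the assumption that the repository root is on PYTHONPATH are illustrative, not taken from this commit):

    from pathlib import Path

    from torch.utils.data import DataLoader

    from utils.dataloaders import MNISTDataset0_3

    # Download the raw IDX files into data/MNIST (if missing) and expose digits 0-3.
    dataset = MNISTDataset0_3(data_path=Path("data"), train=True, download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # expected: torch.Size([32, 1, 28, 28]) torch.Size([32])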

utils/load_data.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from torch.utils.data import Dataset
 
-from .dataloaders import USPSDataset0_6, MNISTDataset0_3
+from .dataloaders import MNISTDataset0_3, USPSDataset0_6
 
 
 def load_data(dataset: str, *args, **kwargs) -> Dataset:
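
Only the import ordering changed here; the body of load_data is not part of this diff. For orientation, a hypothetical dispatcher consistent with the imports above could look like this (the dataset-name strings are made up for illustration and are not from the repository):

    def load_data(dataset: str, *args, **kwargs) -> Dataset:
        # Hypothetical mapping from name to dataset class; not the repo's actual implementation.
        if dataset.lower() == "mnist_0_3":
            return MNISTDataset0_3(*args, **kwargs)
        elif dataset.lower() == "usps_0_6":
            return USPSDataset0_6(*args, **kwargs)
        raise ValueError(f"Unknown dataset: {dataset}")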

utils/load_model.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from .models import ChristianModel, MagnusModel, JanModel
+from .models import ChristianModel, JanModel, MagnusModel
 
 
 def load_model(modelname: str, *args, **kwargs) -> nn.Module:

utils/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 __all__ = ["MagnusModel", "ChristianModel", "JanModel"]
 
 from .christian_model import ChristianModel
-from .magnus_model import MagnusModel
 from .jan_model import JanModel
+from .magnus_model import MagnusModel

utils/models/jan_model.py

Lines changed: 10 additions & 9 deletions
@@ -1,4 +1,5 @@
 import torch
+
 """
 A simple neural network model for classification tasks.
 Parameters
@@ -31,7 +32,6 @@
 import torch.nn as nn
 
 
-
 class JanModel(nn.Module):
     """A simple MLP network model for image classification tasks.
 
@@ -59,22 +59,23 @@ class JanModel(nn.Module):
     fc2 Output Shape: (5, 100)
     out Output Shape: (5, num_classes)
     """
+
     def __init__(self, image_shape, num_classes):
         super().__init__()
-
+
         self.in_channels = image_shape[0]
         self.height = image_shape[1]
         self.width = image_shape[2]
         self.num_classes = num_classes
-
+
         self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)
-
+
         self.fc2 = nn.Linear(100, 100)
-
+
         self.out = nn.Linear(100, num_classes)
-
+
         self.leaky_relu = nn.LeakyReLU()
-
+
         self.flatten = nn.Flatten()
 
     def forward(self, x):
@@ -85,8 +86,8 @@ def forward(self, x):
         x = self.leaky_relu(x)
         x = self.out(x)
         return x
-
-
+
+
 if __name__ == "__main__":
     model = JanModel(2, 4)
