Skip to content

Commit 8d6c07a

Browse files
committed
Merge branch 'johan/dataloader' into johan/micromacro
2 parents a7d51c4 + ff32432 commit 8d6c07a

File tree

4 files changed

+2375
-0
lines changed

4 files changed

+2375
-0
lines changed

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.12

pyproject.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
[project]
2+
name = "collaborative-coding-exam"
3+
version = "0.1.0"
4+
description = "Exam project in the collaborative coding course."
5+
readme = "README.md"
6+
requires-python = ">=3.12"
7+
dependencies = [
8+
"black>=25.1.0",
9+
"h5py>=3.12.1",
10+
"isort>=6.0.0",
11+
"jupyterlab>=4.3.5",
12+
"numpy>=2.2.2",
13+
"pandas>=2.2.3",
14+
"pip>=25.0",
15+
"pytest>=8.3.4",
16+
"ruff>=0.9.4",
17+
"scalene>=1.5.51",
18+
"sphinx>=8.1.3",
19+
"sphinx-autoapi>=3.4.0",
20+
"sphinx-autobuild>=2024.10.3",
21+
"sphinx-rtd-theme>=3.0.2",
22+
"torch>=2.6.0",
23+
"torchvision>=0.21.0",
24+
]
125
[tool.isort]
226
profile = "black"
327
line_length = 88

utils/dataloaders/mnist_4_9.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import gzip
2+
import os
3+
import urllib.request as ur
4+
from pathlib import Path
5+
import numpy as np
6+
from torch.utils.data import Dataset
7+
8+
class MNIST_4_9(Dataset):
    """Locate (and optionally download) the raw MNIST IDX files.

    The four uncompressed IDX files are expected in ``<datapath>/MNIST``.
    Sample access is not implemented yet: ``__len__`` and ``__getitem__``
    raise ``NotImplementedError`` until the loading logic is added.

    Args:
        datapath: Root data directory; files live in its ``MNIST`` subfolder.
            A plain string is accepted and converted to ``Path``.
        train: Stored as ``self.train``; not otherwise used by this class yet.
        download: If True, fetch any missing files. If False, missing files
            cause ``FileNotFoundError``.

    Raises:
        FileNotFoundError: If required files are missing and ``download`` is
            False.
    """

    # Remote location of the gzip-compressed originals.
    _BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    # Names of the uncompressed files that must exist locally; the remote
    # archive for each one is the same name with a ".gz" suffix.
    _REQUIRED_FILES = (
        "train-images-idx3-ubyte",
        "train-labels-idx1-ubyte",
        "t10k-images-idx3-ubyte",
        "t10k-labels-idx1-ubyte",
    )

    # Decompress in 1 MiB pieces instead of reading a whole file into memory.
    _CHUNK_SIZE = 1 << 20

    def __init__(
        self,
        datapath: Path,
        train: bool = False,
        download: bool = False,
    ):
        # `super.__init__()` (without the call parentheses) raises TypeError
        # because it invokes the descriptor on the `super` type itself.
        super().__init__()
        self.datapath = Path(datapath)
        self.mnist_path = self.datapath / "MNIST"
        self.train = train
        self.download = download
        # NOTE(review): presumably the six digits 4..9 implied by the class
        # name -- confirm once label filtering is implemented.
        self.num_classes: int = 6

        # Check once; only touch the network when the caller opted in.
        if not self._already_downloaded():
            if not self.download:
                raise FileNotFoundError(
                    'Data files are not found. Set --download-data=True to download the data'
                )
            self._download()

    def _download(self) -> None:
        """Download and decompress every required file that is missing.

        Each archive is fetched next to its target, decompressed into a
        temporary ``.part`` file and renamed into place only on success, so
        an interrupted run never leaves a truncated file that would later be
        mistaken for a complete download.
        """
        self.mnist_path.mkdir(parents=True, exist_ok=True)

        for name in self._REQUIRED_FILES:
            target = self.mnist_path / name
            if target.exists():
                print(f"File: {target} already downloaded")
                continue

            print(f"File: {target} is downloading...")
            archive = self.mnist_path / f"{name}.gz"
            partial = self.mnist_path / f"{name}.part"
            try:
                ur.urlretrieve(self._BASE_URL + archive.name, str(archive))
                with gzip.open(archive, "rb") as infile, open(partial, "wb") as outfile:
                    while chunk := infile.read(self._CHUNK_SIZE):
                        outfile.write(chunk)
                partial.replace(target)  # publish only a fully written file
            finally:
                # Always clean up the archive and any leftover partial output.
                archive.unlink(missing_ok=True)
                partial.unlink(missing_ok=True)

    def _already_downloaded(self) -> bool:
        """Return True if all required files exist in ``self.mnist_path``.

        This is a pure check: it does not create the directory.
        """
        return all(
            (self.mnist_path / name).is_file() for name in self._REQUIRED_FILES
        )

    def __len__(self) -> int:
        """Return the number of samples (not implemented yet)."""
        raise NotImplementedError("MNIST_4_9.__len__ is not implemented yet")

    def __getitem__(self, index):
        """Return the sample at ``index`` (not implemented yet)."""
        raise NotImplementedError("MNIST_4_9.__getitem__ is not implemented yet")
74+
75+

0 commit comments

Comments
 (0)