-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
60 lines (48 loc) · 2.43 KB
/
dataset.py
File metadata and controls
60 lines (48 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from mnist import MNIST # pip install python-mnist
import logging
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
def build_dataset_and_loader(batch_size: int, partition: str, logger: logging.Logger, data_dir="./data/",
                             num_workers: int = 2):
    """Build an MNIST dataset for ``partition`` and a ``DataLoader`` that batches it.

    :param batch_size: number of samples per mini-batch
    :param partition: ``'training'`` or ``'testing'``; only the training partition is shuffled
    :param logger: logger handed to ``MNISTDataset`` for load reporting
    :param data_dir: directory containing the raw MNIST files
    :param num_workers: number of ``DataLoader`` worker processes (default 2, the value that
        used to be hard-coded); pass 0 to load data in the main process
    :return: ``(dataset, dataloader)``; every batch yielded by the loader is the dict built by
        ``MNISTDataset.collate_fnc``
    :raises ValueError: if ``partition`` is not ``'training'`` or ``'testing'``
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``.
    if partition not in ('training', 'testing'):
        raise ValueError(f"{partition} is an invalid partition")
    dataset = MNISTDataset(partition, data_dir, logger)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(partition == 'training'),  # shuffle only for training; test order stays fixed
        num_workers=num_workers,
        collate_fn=dataset.collate_fnc,
    )
    return dataset, dataloader
class MNISTDataset(Dataset):
    """Map-style torch ``Dataset`` over one partition (training or testing) of MNIST.

    Raw data is read with ``python-mnist``. Each sample is a flattened ``imsize * imsize``
    image scaled to [0, 1] plus an integer label. ``collate_fnc`` stacks a list of such
    samples into NumPy mini-batch arrays.
    """

    imsize = 28  # side length of an MNIST image; samples are flattened to imsize * imsize

    def __init__(self, partition: str, mnist_dir: str, logger: logging.Logger):
        """Load the images and labels of ``partition`` from ``mnist_dir``.

        :param partition: ``'training'`` or ``'testing'``; selects ``load_training`` / ``load_testing``
            on the ``MNIST`` reader
        :param mnist_dir: directory holding the raw MNIST files, passed to ``MNIST``
        :param logger: logger used to report how many images were loaded
        :raises ValueError: if ``partition`` is not one of the two supported names
        """
        # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``.
        if partition not in ('training', 'testing'):
            raise ValueError(f"{partition} is an invalid partition")
        mnist = MNIST(mnist_dir)
        raw_dataset_parser = getattr(mnist, f"load_{partition}")
        images, labels = raw_dataset_parser()
        # The reader presumably returns plain Python sequences (python-mnist's default
        # return type). Packing them once into compact NumPy arrays shrinks the object
        # copied into every DataLoader worker several-fold. It also avoids re-parsing a
        # 784-element Python list on every __getitem__.
        self.images = np.asarray(images, dtype=np.uint8)  # (N, imsize * imsize), raw 0-255 pixels
        self.labels = np.asarray(labels, dtype=np.int64)  # (N,)
        self.nimages = len(self.images)
        logger.info("Loaded %d images for %s", self.nimages, partition)

    def __len__(self):
        """Return the number of images in this partition."""
        return self.nimages

    def __getitem__(self, index: int) -> Dict:
        """
        Returns a single data point (image-label pair) from the dataset.
        :param index: Index of the data point to be accessed
        :return: {
            'image': np.ndarray, float32, shape (28 * 28), pixel values scaled to [0, 1].
                     NOTE: images are stored flattened into a 1-d array
            'label': int
        }
        """
        # asarray with an explicit dtype always yields a fresh float32 copy here, so the
        # division below never touches the stored uint8 data.
        img = np.asarray(self.images[index], dtype=np.float32) / 255.0
        label = int(self.labels[index])  # plain Python int, not a NumPy scalar
        return {'image': img, 'label': label}

    @staticmethod
    def collate_fnc(data_batch: List[Dict]) -> Dict:
        """
        Converts a batch of individual data points into a mini-batch.
        :param data_batch: A List of N dicts, each containing:
            'image': np.ndarray, shape (28 * 28),
            'label': int
        :return: A single dictionary containing:
            'image': np.ndarray, float32, shape (N, 28 * 28),
            'label': np.ndarray, int64, shape (N)
        """
        images = np.array([data['image'] for data in data_batch], dtype=np.float32)  # (N, 784)
        labels = np.array([data['label'] for data in data_batch], dtype=np.int64)  # (N,)
        return {'image': images, 'label': labels}