Skip to content

Commit 5af2c61

Browse files
committed
fixed minor bugs discovered during testing the dataloader
1 parent 0043e11 commit 5af2c61

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

utils/dataloaders/mnist_0_3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def __init__(self, data_path: Path, train: bool = False, transform=None, downloa
6262
self.train = train
6363
self.transform = transform
6464
self.download = download
65+
self.num_classes = 4
6566

67+
if not self.download and not self._chech_is_downloaded():
68+
raise ValueError("Data not found. Set --download-data=True to download the data.")
6669
if self.download and not self._chech_is_downloaded():
6770
self._download_data()
6871

@@ -121,8 +124,8 @@ def __getitem__(self, index):
121124
label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label
122125

123126
with open(self.images_path, "rb") as f:
124-
f.seek(16 + index * 28) # Jump to image position
125-
image = np.frombuffer(f.read(28), dtype=np.uint8).reshape(28, 28) # Read image data
127+
f.seek(16 + index * 28*28) # Jump to image position
128+
image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data
126129

127130
if self.transform:
128131
image = self.transform(image)

0 commit comments

Comments
 (0)