Skip to content

Commit a4214d2

Browse files
committed
Had to modify to fit in the overall format
1 parent 744699f commit a4214d2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

utils/dataloaders/uspsh5_7_9.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3030
A transform function to apply to the images.
3131
"""
3232

33-
def __init__(self, h5_path, mode, transform=None):
33+
filename = "usps.h5"
34+
35+
def __init__(self, data_path, train=False, transform=None, download=False):
3436
super().__init__()
3537
"""
3638
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -45,8 +47,8 @@ def __init__(self, h5_path, mode, transform=None):
4547
"""
4648

4749
self.transform = transform
48-
self.mode = mode
49-
self.h5_path = h5_path
50+
self.mode = "train" if train else "test"
51+
self.h5_path = data_path / self.filename
5052
# Load the dataset from the HDF5 file
5153
with h5py.File(self.h5_path, "r") as hf:
5254
images = hf[self.mode]["data"][:]

0 commit comments

Comments
 (0)