44from PIL import Image
55from torch .utils .data import Dataset
66from torchvision import transforms
7+ from pathlib import Path
78
89
910class USPSH5_Digit_7_9_Dataset (Dataset ):
@@ -30,7 +31,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3031 A transform function to apply to the images.
3132 """
3233
33- def __init__ (self , h5_path , mode , transform = None ):
34+ def __init__ (self , data_path , train = False , transform = None ):
3435 super ().__init__ ()
3536 """
3637 Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -43,12 +44,13 @@ def __init__(self, h5_path, mode, transform=None):
4344 transform : callable, optional, default=None
4445 A transform function to apply on images.
4546 """
46-
47+ self .filename = "usps.h5"
48+ path = data_path if isinstance (data_path , Path ) else Path (data_path )
49+ self .filepath = path / self .filename
4750 self .transform = transform
48- self .mode = mode
49- self .h5_path = h5_path
51+ self .mode = "train" if train else "test"
5052 # Load the dataset from the HDF5 file
51- with h5py .File (self .h5_path , "r" ) as hf :
53+ with h5py .File (self .filepath , "r" ) as hf :
5254 images = hf [self .mode ]["data" ][:]
5355 labels = hf [self .mode ]["target" ][:]
5456
@@ -105,8 +107,8 @@ def main():
105107
106108 # Load the dataset
107109 dataset = USPSH5_Digit_7_9_Dataset (
108- h5_path = "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5 " ,
109- mode = " train" ,
110+ data_path = "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git" ,
111+ train = False ,
110112 transform = transform ,
111113 )
112114 data_loader = torch .utils .data .DataLoader (dataset , batch_size = 2 , shuffle = True )
0 commit comments