1+ from pathlib import Path
2+
13import h5py
24import numpy as np
35import torch
@@ -30,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3032 A transform function to apply to the images.
3133 """
3234
33- def __init__ (self , h5_path , mode , transform = None ):
35+ def __init__ (self , data_path , train = False , transform = None ):
3436 super ().__init__ ()
3537 """
3638 Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -43,12 +45,13 @@ def __init__(self, h5_path, mode, transform=None):
4345 transform : callable, optional, default=None
4446 A transform function to apply on images.
4547 """
46-
48+ self .filename = "usps.h5"
49+ path = data_path if isinstance (data_path , Path ) else Path (data_path )
50+ self .filepath = path / self .filename
4751 self .transform = transform
48- self .mode = mode
49- self .h5_path = h5_path
52+ self .mode = "train" if train else "test"
5053 # Load the dataset from the HDF5 file
51- with h5py .File (self .h5_path , "r" ) as hf :
54+ with h5py .File (self .filepath , "r" ) as hf :
5255 images = hf [self .mode ]["data" ][:]
5356 labels = hf [self .mode ]["target" ][:]
5457
@@ -105,8 +108,8 @@ def main():
105108
106109 # Load the dataset
107110 dataset = USPSH5_Digit_7_9_Dataset (
108- h5_path = "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5 " ,
109- mode = " train" ,
111+ data_path = "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git" ,
112+ train = False ,
110113 transform = transform ,
111114 )
112115 data_loader = torch .utils .data .DataLoader (dataset , batch_size = 2 , shuffle = True )
0 commit comments