@@ -66,40 +66,19 @@ class _SemanticSegmentationDataset(_ContinuumDataset):
6666 def data_type (self ) -> str :
6767 return "segmentation"
6868
69- class ContinuumDataset (_ContinuumDataset ):
70- """Continuum version of torchvision datasets.
71-
72- :param dataset: A Torchvision dataset, like MNIST or CIFAR100.
73-
74- This class avoid to have to deal with specific parameters of some Pytorch dataset while creating them
75- """
76- def __init__ (
77- self , dataset
78- ):
79- super ().__init__ (data_path = dataset .root , train = dataset .train , download = False )
80- self .dataset = dataset
81-
82- def get_data (self ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
83- x , y = np .array (self .dataset .data ), np .array (self .dataset .targets )
84- return x , y , None
85-
86-
8769class PyTorchDataset (_ContinuumDataset ):
8870 """Continuum version of torchvision datasets.
89-
9071 :param dataset_type: A Torchvision dataset, like MNIST or CIFAR100.
9172 :param train: train flag
9273 :param download: download
9374 """
9475
9576 # TODO: some datasets have a different structure, like SVHN for ex. Handle it.
9677 def __init__ (
97- self , data_path : str = "" , dataset_type = None , train : bool = True , download : bool = True
98- ):
78+ self , data_path : str = "" , dataset_type = None , train : bool = True , download : bool = True , ** kwargs ):
9979 super ().__init__ (data_path = data_path , train = train , download = download )
100-
10180 self .dataset_type = dataset_type
102- self .dataset = self .dataset_type (self .data_path , download = self .download , train = self .train )
81+ self .dataset = self .dataset_type (self .data_path , download = self .download , train = self .train , ** kwargs )
10382
10483 def get_data (self ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
10584 x , y = np .array (self .dataset .data ), np .array (self .dataset .targets )
@@ -116,13 +95,13 @@ class InMemoryDataset(_ContinuumDataset):
11695 """
11796
11897 def __init__ (
119- self ,
120- x : np .ndarray ,
121- y : np .ndarray ,
122- t : Union [None , np .ndarray ] = None ,
123- data_type : str = "image_array" ,
124- train : bool = True ,
125- download : bool = True ,
98+ self ,
99+ x : np .ndarray ,
100+ y : np .ndarray ,
101+ t : Union [None , np .ndarray ] = None ,
102+ data_type : str = "image_array" ,
103+ train : bool = True ,
104+ download : bool = True ,
126105 ):
127106 super ().__init__ (train = train , download = download )
128107
@@ -160,7 +139,6 @@ def __init__(self, data_path: str, train: bool = True, download: bool = True, da
160139 self .data_path = data_path
161140 super ().__init__ (data_path = data_path , train = train , download = download )
162141
163-
164142 allowed_data_types = ("image_path" , "segmentation" )
165143 if data_type not in allowed_data_types :
166144 raise ValueError (f"Invalid data_type={ data_type } , allowed={ allowed_data_types } ." )
0 commit comments