@@ -66,21 +66,19 @@ class _SemanticSegmentationDataset(_ContinuumDataset):
6666 def data_type (self ) -> str :
6767 return "segmentation"
6868
69-
7069class PyTorchDataset (_ContinuumDataset ):
7170 """Continuum version of torchvision datasets.
72-
7371 :param dataset_type: A Torchvision dataset, like MNIST or CIFAR100.
72+ :param train: train flag
73+ :param download: download
7474 """
7575
7676 # TODO: some datasets have a different structure, like SVHN for ex. Handle it.
7777 def __init__ (
78- self , data_path : str = "" , dataset_type = None , train : bool = True , download : bool = True
79- ):
78+ self , data_path : str = "" , dataset_type = None , train : bool = True , download : bool = True , ** kwargs ):
8079 super ().__init__ (data_path = data_path , train = train , download = download )
81-
8280 self .dataset_type = dataset_type
83- self .dataset = self .dataset_type (self .data_path , download = self .download , train = self .train )
81+ self .dataset = self .dataset_type (self .data_path , download = self .download , train = self .train , ** kwargs )
8482
8583 def get_data (self ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
8684 x , y = np .array (self .dataset .data ), np .array (self .dataset .targets )
@@ -97,13 +95,13 @@ class InMemoryDataset(_ContinuumDataset):
9795 """
9896
9997 def __init__ (
100- self ,
101- x : np .ndarray ,
102- y : np .ndarray ,
103- t : Union [None , np .ndarray ] = None ,
104- data_type : str = "image_array" ,
105- train : bool = True ,
106- download : bool = True ,
98+ self ,
99+ x : np .ndarray ,
100+ y : np .ndarray ,
101+ t : Union [None , np .ndarray ] = None ,
102+ data_type : str = "image_array" ,
103+ train : bool = True ,
104+ download : bool = True ,
107105 ):
108106 super ().__init__ (train = train , download = download )
109107
@@ -141,7 +139,6 @@ def __init__(self, data_path: str, train: bool = True, download: bool = True, da
141139 self .data_path = data_path
142140 super ().__init__ (data_path = data_path , train = train , download = download )
143141
144-
145142 allowed_data_types = ("image_path" , "segmentation" )
146143 if data_type not in allowed_data_types :
147144 raise ValueError (f"Invalid data_type={ data_type } , allowed={ allowed_data_types } ." )
0 commit comments