Skip to content

Commit bb38814

Browse files
committed
Add a **kwargs argument to PyTorchDataset so that supplementary arguments can be forwarded to the underlying torchvision dataset.
1 parent d176fe1 commit bb38814

File tree

1 file changed

+9
-31
lines changed

1 file changed

+9
-31
lines changed

continuum/datasets/base.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -66,40 +66,19 @@ class _SemanticSegmentationDataset(_ContinuumDataset):
6666
def data_type(self) -> str:
    """Return the data-type tag identifying this dataset family: ``"segmentation"``."""
    return "segmentation"
6868

69-
class ContinuumDataset(_ContinuumDataset):
    """Wrap an already-instantiated torchvision dataset (e.g. MNIST, CIFAR100).

    NOTE(review): this is the class *deleted* by this commit; it is superseded
    by ``PyTorchDataset(..., **kwargs)``, which instantiates the torchvision
    dataset itself.

    This class avoids having to deal with the specific constructor parameters
    of each torchvision dataset, since the caller builds the instance.

    :param dataset: An instantiated torchvision dataset exposing ``root``,
        ``train``, ``data`` and ``targets`` attributes.
    """

    def __init__(self, dataset):
        # The wrapped instance already knows its root folder and split, and is
        # necessarily already on disk — hence download=False.
        super().__init__(data_path=dataset.root, train=dataset.train, download=False)
        self.dataset = dataset

    def get_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(x, y, t)`` arrays; no task ids exist here, so ``t`` is None."""
        samples = np.array(self.dataset.data)
        labels = np.array(self.dataset.targets)
        return samples, labels, None
8769
class PyTorchDataset(_ContinuumDataset):
    """Continuum wrapper that instantiates a torchvision dataset class.

    :param data_path: Folder where the dataset is (or will be) stored.
    :param dataset_type: A torchvision dataset *class* (not an instance),
        like ``MNIST`` or ``CIFAR100``. Required despite the ``None`` default,
        which only exists to keep the positional parameter order stable.
    :param train: Load the train split when True, the test split otherwise.
    :param download: Let torchvision download the data if it is missing.
    :param kwargs: Supplementary keyword arguments forwarded verbatim to
        ``dataset_type`` (e.g. ``transform=...``).
    :raises ValueError: If ``dataset_type`` is not provided.
    """

    # TODO: some datasets have a different structure, like SVHN for ex. Handle it.
    def __init__(
        self, data_path: str = "", dataset_type=None, train: bool = True, download: bool = True, **kwargs
    ):
        super().__init__(data_path=data_path, train=train, download=download)

        # Fail early with an explicit error instead of the cryptic
        # "TypeError: 'NoneType' object is not callable" a few lines below.
        if dataset_type is None:
            raise ValueError("A torchvision dataset class must be given as `dataset_type`.")

        self.dataset_type = dataset_type
        # **kwargs is what this commit adds: dataset-specific options flow
        # straight through to the torchvision constructor.
        self.dataset = self.dataset_type(self.data_path, download=self.download, train=self.train, **kwargs)
10483
def get_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10584
x, y = np.array(self.dataset.data), np.array(self.dataset.targets)
@@ -116,13 +95,13 @@ class InMemoryDataset(_ContinuumDataset):
11695
"""
11796

11897
def __init__(
119-
self,
120-
x: np.ndarray,
121-
y: np.ndarray,
122-
t: Union[None, np.ndarray] = None,
123-
data_type: str = "image_array",
124-
train: bool = True,
125-
download: bool = True,
98+
self,
99+
x: np.ndarray,
100+
y: np.ndarray,
101+
t: Union[None, np.ndarray] = None,
102+
data_type: str = "image_array",
103+
train: bool = True,
104+
download: bool = True,
126105
):
127106
super().__init__(train=train, download=download)
128107

@@ -160,7 +139,6 @@ def __init__(self, data_path: str, train: bool = True, download: bool = True, da
160139
self.data_path = data_path
161140
super().__init__(data_path=data_path, train=train, download=download)
162141

163-
164142
allowed_data_types = ("image_path", "segmentation")
165143
if data_type not in allowed_data_types:
166144
raise ValueError(f"Invalid data_type={data_type}, allowed={allowed_data_types}.")

0 commit comments

Comments
 (0)