Skip to content

Commit aca9c56

Browse files
Merge pull request #127 from Continvvm/continuumDataset
Add kwargs to pytorch datasets
2 parents 16deee9 + 35625a9 commit aca9c56

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

continuum/datasets/base.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,19 @@ class _SemanticSegmentationDataset(_ContinuumDataset):
6666
def data_type(self) -> str:
6767
return "segmentation"
6868

69-
7069
class PyTorchDataset(_ContinuumDataset):
7170
"""Continuum version of torchvision datasets.
72-
7371
:param dataset_type: A Torchvision dataset, like MNIST or CIFAR100.
72+
:param train: train flag
73+
:param download: download
7474
"""
7575

7676
# TODO: some datasets have a different structure, like SVHN for ex. Handle it.
7777
def __init__(
78-
self, data_path: str = "", dataset_type=None, train: bool = True, download: bool = True
79-
):
78+
self, data_path: str = "", dataset_type=None, train: bool = True, download: bool = True, **kwargs):
8079
super().__init__(data_path=data_path, train=train, download=download)
81-
8280
self.dataset_type = dataset_type
83-
self.dataset = self.dataset_type(self.data_path, download=self.download, train=self.train)
81+
self.dataset = self.dataset_type(self.data_path, download=self.download, train=self.train, **kwargs)
8482

8583
def get_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
8684
x, y = np.array(self.dataset.data), np.array(self.dataset.targets)
@@ -97,13 +95,13 @@ class InMemoryDataset(_ContinuumDataset):
9795
"""
9896

9997
def __init__(
100-
self,
101-
x: np.ndarray,
102-
y: np.ndarray,
103-
t: Union[None, np.ndarray] = None,
104-
data_type: str = "image_array",
105-
train: bool = True,
106-
download: bool = True,
98+
self,
99+
x: np.ndarray,
100+
y: np.ndarray,
101+
t: Union[None, np.ndarray] = None,
102+
data_type: str = "image_array",
103+
train: bool = True,
104+
download: bool = True,
107105
):
108106
super().__init__(train=train, download=download)
109107

@@ -141,7 +139,6 @@ def __init__(self, data_path: str, train: bool = True, download: bool = True, da
141139
self.data_path = data_path
142140
super().__init__(data_path=data_path, train=train, download=download)
143141

144-
145142
allowed_data_types = ("image_path", "segmentation")
146143
if data_type not in allowed_data_types:
147144
raise ValueError(f"Invalid data_type={data_type}, allowed={allowed_data_types}.")

tests/test_dataset_attributes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22

33
from continuum import datasets as cont_datasets
4+
from torchvision.datasets import EMNIST, KMNIST
5+
from continuum.datasets import PyTorchDataset
46

57
ATTRS = ["get_data", "_download"]
68

@@ -11,3 +13,13 @@ def test_has_attr(dataset_name):
1113

1214
for attr in ATTRS:
1315
assert hasattr(d, attr), (dataset_name, attr)
16+
17+
18+
@pytest.mark.slow
19+
def test_PytorchDataset_EMNIST(tmpdir):
20+
dataset_train = PyTorchDataset(tmpdir, dataset_type=EMNIST, train=True, download=True, split='letters')
21+
22+
23+
@pytest.mark.slow
24+
def test_PytorchDataset_KMNIST(tmpdir):
25+
dataset_train = PyTorchDataset(tmpdir, dataset_type=KMNIST, train=True, download=True)

0 commit comments

Comments
 (0)