Skip to content

Commit 693a3a7

Browse files
authored
Added ComposeDataset implementation to support dataset composition. (#104) (#105)
1 parent 80038df commit 693a3a7

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

mipcandy/data/dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,29 @@ def construct_new(self, images: UnsupervisedDataset, labels: UnsupervisedDataset
134134
return MergedDataset(DatasetFromMemory(images), DatasetFromMemory(labels), device=self._device)
135135

136136

137+
class ComposeDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]):
    """Expose several datasets as one dataset laid end to end.

    Global indices are assigned in the order the base datasets are given:
    the first dataset covers ``[0, len(first))``, the second one continues
    right after it, and so on. The length of each base dataset is read once,
    at construction time.
    """

    def __init__(self, bases: Sequence[SupervisedDataset] | Sequence[UnsupervisedDataset], *,
                 device: Device = "cpu") -> None:
        """Record the global index range served by each base dataset.

        :param bases: datasets to concatenate, in the order they should appear
        :param device: forwarded unchanged to the parent initializer
        """
        super().__init__(device)
        # Half-open global range ``(first, past_last)`` -> dataset that serves it.
        # Insertion order equals concatenation order.
        self._bases: dict[tuple[int, int], SupervisedDataset | UnsupervisedDataset] = {}
        self._len = 0
        for member in bases:
            size = len(member)
            first = self._len
            past_last = first + size
            self._bases[(first, past_last)] = member
            self._len = past_last

    @override
    def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """Load the item at global position ``idx`` from the base dataset that owns it.

        :param idx: position in the composed dataset
        :return: whatever the owning dataset's ``load`` returns for the local position
        :raises IndexError: if no recorded range contains ``idx``
        """
        for span, member in self._bases.items():
            first, past_last = span
            if not (first <= idx < past_last):
                continue
            # Translate the global position into the member's own index space.
            return member.load(idx - first)
        raise IndexError(f"Index {idx} out of range [0, {self._len})")

    @override
    def __len__(self) -> int:
        """Return the summed length of all base datasets, as measured at construction."""
        return self._len
158+
159+
137160
class PathBasedUnsupervisedDataset(UnsupervisedDataset[list[str]], metaclass=ABCMeta):
138161
def paths(self) -> list[str]:
139162
return self._images

0 commit comments

Comments
 (0)