Skip to content

Commit e3297c6

Browse files
committed
add dummy dataset class - quick testing purpose
1 parent ff03d6f commit e3297c6

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# This file is for developers only
2+
3+
__all__ = [] # Nothing should be imported from this file
4+
5+
6+
import random
7+
8+
import numpy as np
9+
from torch.utils.data import DataLoader, Dataset
10+
11+
from chebai.preprocessing.datasets import XYBaseDataModule
12+
from chebai.preprocessing.reader import ChemDataReader
13+
14+
15+
class _DummyDataModule(XYBaseDataModule):
16+
17+
READER = ChemDataReader
18+
19+
def __init__(self, num_of_labels: int, feature_vector_size: int, *args, **kwargs):
20+
super().__init__(*args, **kwargs)
21+
self._num_of_labels = num_of_labels
22+
self._feature_vector_size = feature_vector_size
23+
assert self._num_of_labels is not None
24+
assert self._feature_vector_size is not None
25+
26+
def prepare_data(self):
27+
pass
28+
29+
def setup(self, stage=None):
30+
pass
31+
32+
@property
33+
def num_of_labels(self):
34+
return self._num_of_labels
35+
36+
@property
37+
def feature_vector_size(self):
38+
return self._feature_vector_size
39+
40+
def train_dataloader(self, *args, **kwargs) -> DataLoader:
41+
dataset = _DummyDataset(100, self.num_of_labels, self.feature_vector_size)
42+
return DataLoader(
43+
dataset,
44+
collate_fn=self.reader.collator,
45+
batch_size=self.batch_size,
46+
**kwargs,
47+
)
48+
49+
def test_dataloader(self, *args, **kwargs) -> DataLoader:
50+
dataset = _DummyDataset(20, self.num_of_labels, self.feature_vector_size)
51+
return DataLoader(
52+
dataset,
53+
collate_fn=self.reader.collator,
54+
batch_size=self.batch_size,
55+
**kwargs,
56+
)
57+
58+
def val_dataloader(self, *args, **kwargs) -> DataLoader:
59+
dataset = _DummyDataset(10, self.num_of_labels, self.feature_vector_size)
60+
return DataLoader(
61+
dataset,
62+
collate_fn=self.reader.collator,
63+
batch_size=self.batch_size,
64+
**kwargs,
65+
)
66+
67+
@property
68+
def _name(self) -> str:
69+
return "_DummyDataModule"
70+
71+
72+
class _DummyDataset(Dataset):
73+
def __init__(self, num_samples: int, num_labels: int, feature_vector_size: int):
74+
self.num_samples = num_samples
75+
self.num_labels = num_labels
76+
self.feature_vector_size = feature_vector_size
77+
78+
def __len__(self):
79+
return self.num_samples
80+
81+
def __getitem__(self, idx):
82+
return {
83+
"features": np.random.randint(
84+
10, 100, size=self.feature_vector_size
85+
), # Random feature vector
86+
"labels": np.random.choice(
87+
[False, True], size=self.num_labels
88+
), # Random boolean labels
89+
"ident": random.randint(1, 40000), # Random identifier
90+
"group": None, # Default group value
91+
}
92+
93+
94+
if __name__ == "__main__":
95+
dataset = _DummyDataset(num_samples=100, num_labels=5, feature_vector_size=20)
96+
for i in range(10):
97+
print(dataset[i])

configs/data/_dummy.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class_path: chebai.preprocessing.datasets._dummy._DummyDataModule
2+
init_args:
3+
feature_vector_size: 20

0 commit comments

Comments
 (0)