Skip to content

Commit bb73784

Browse files
committed
add batched pubchem dataset
1 parent 3ab691f commit bb73784

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,78 @@ def _perform_data_preparation(self, *args, **kwargs):
212212
print("Done")
213213

214214

215+
class PubChemBatched(PubChem):
216+
"""Store train data as batches of 10m, validation and test should each be 100k max"""
217+
218+
def __init__(self, *args, **kwargs):
219+
super(PubChemBatched, self).__init__(*args, **kwargs)
220+
self.train_batch_size = 10_000_000
221+
if self.k != self.FULL:
222+
self.val_batch_size = (
223+
100_000
224+
if self.validation_split * self.k > 100_000
225+
else int(self.validation_split * self.k)
226+
)
227+
self.test_batch_size = (
228+
100_000
229+
if self.test_split * self.k > 100_000
230+
else int(self.test_split * self.k)
231+
)
232+
else:
233+
self.val_batch_size = 100_000
234+
self.test_batch_size = 100_000
235+
236+
@property
237+
def processed_file_names(self) -> List[str]:
238+
"""
239+
Returns:
240+
List[str]: List of processed data file names.
241+
"""
242+
train_samples = (
243+
self._k if self._k != self.FULL else 120_000_000
244+
) # estimate size
245+
train_samples -= self.val_batch_size + self.test_batch_size
246+
train_batches = (
247+
["train.pt"]
248+
if train_samples <= self.train_batch_size
249+
else [
250+
f"train_{i}.pt"
251+
for i in range((train_samples // self.train_batch_size) + 1)
252+
]
253+
)
254+
return train_batches + ["test.pt", "validation.pt"]
255+
256+
def setup_processed(self):
257+
"""
258+
Prepares processed data and saves them as Torch tensors.
259+
"""
260+
filename = os.path.join(self.raw_dir, self.raw_file_names[0])
261+
print("Load data from file", filename)
262+
data = self._load_data_from_file(filename)
263+
print("Create splits")
264+
train, test = train_test_split(
265+
data, test_size=self.test_batch_size + self.val_batch_size
266+
)
267+
del data
268+
test, val = train_test_split(test, train_size=self.test_batch_size)
269+
torch.save(test, os.path.join(self.processed_dir, "test.pt"))
270+
torch.save(val, os.path.join(self.processed_dir, "validation.pt"))
271+
272+
# batch training if necessary
273+
if len(train) > self.train_batch_size:
274+
train_batches = [
275+
train[i : i + self.train_batch_size]
276+
for i in range(0, len(train), self.train_batch_size)
277+
]
278+
train = [torch.tensor(batch) for batch in train_batches]
279+
for i, batch in enumerate(train):
280+
torch.save(batch, os.path.join(self.processed_dir, f"train_{i}.pt"))
281+
else:
282+
torch.save(train, os.path.join(self.processed_dir, "train.pt"))
283+
284+
self.reader.on_finish()
285+
286+
215287
class PubChemDissimilar(PubChem):
216288
"""
217289
Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint)

0 commit comments

Comments
 (0)