Commit 8c0454c
fix batched pubchem
1 parent 3b233d6

1 file changed: +48 -12 lines

chebai/preprocessing/datasets/pubchem.py

@@ -262,41 +262,77 @@ def processed_file_names_dict(self) -> List[str]:
         train_batches["validation"] = "validation.pt"
         return train_batches
 
+    def _tokenize_batched(self, data):
+        """
+        Tokenize the given entries and yield them in batches of 1,000,000 entries.
+
+        Args:
+            data (List[dict]): The entries to tokenize.
+
+        Yields:
+            List: A list of dictionaries containing the features and labels,
+            filtered by the token limit if one is set.
+        """
+        print(f"Processing {len(data)} lines...")
+        batch = []
+        for i, d in enumerate(tqdm.tqdm(data, total=len(data))):
+            if d["features"] is not None:
+                batch.append(self.reader.to_data(d))
+            if i % 1_000_000 == 0 and i > 0:
+                print(f"Saving batch {i // 1_000_000}")
+                # keep entries that tokenized successfully and fit the token limit
+                batch = [
+                    b
+                    for b in batch
+                    if b["features"] is not None
+                    and (
+                        self.n_token_limit is None
+                        or len(b["features"]) <= self.n_token_limit
+                    )
+                ]
+                yield batch
+                batch = []
+        print("Saving final batch")
+        batch = [
+            b
+            for b in batch
+            if b["features"] is not None
+            and (
+                self.n_token_limit is None
+                or len(b["features"]) <= self.n_token_limit
+            )
+        ]
+        yield batch
+
     def setup_processed(self):
         """
         Prepares processed data and saves them as Torch tensors.
         """
         filename = os.path.join(self.raw_dir, self.raw_file_names[0])
         print("Load data from file", filename)
-        data = self._load_data_from_file(filename)
+        data_not_tokenized = [entry for entry in self._load_dict(filename)]
         print("Create splits")
         train, test = train_test_split(
-            data, test_size=self.test_batch_size + self.val_batch_size
+            data_not_tokenized, test_size=self.test_batch_size + self.val_batch_size
         )
-        del data
+        del data_not_tokenized
         test, val = train_test_split(test, train_size=self.test_batch_size)
         torch.save(
-            test,
+            # materialize the generator: torch.save cannot pickle a generator object
+            [x for b in self._tokenize_batched(test) for x in b],
             os.path.join(self.processed_dir, self.processed_file_names_dict["test"]),
         )
         torch.save(
-            val,
+            [x for b in self._tokenize_batched(val) for x in b],
             os.path.join(
                 self.processed_dir, self.processed_file_names_dict["validation"]
             ),
         )
 
         # batch training if necessary
         if len(train) > self.train_batch_size:
-            train_batches = [
-                train[i : i + self.train_batch_size]
-                for i in range(0, len(train), self.train_batch_size)
-            ]
-            train = [torch.tensor(batch) for batch in train_batches]
-            for i, batch in enumerate(train):
+            for i, batch in enumerate(self._tokenize_batched(train)):
                 torch.save(batch, os.path.join(self.processed_dir, f"train_{i}.pt"))
         else:
-            torch.save(train, os.path.join(self.processed_dir, "train.pt"))
+            torch.save(
+                [x for b in self._tokenize_batched(train) for x in b],
+                os.path.join(self.processed_dir, "train.pt"),
+            )
 
         self.reader.on_finish()

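For reference, the chunked tokenize-filter-yield pattern that _tokenize_batched implements can be sketched standalone. This is a minimal illustration, not chebai code: tokenize_in_batches and the character-level stand-in tokenizer are made up for the example; the real method uses self.reader.to_data and self.n_token_limit instead.

import tqdm

def tokenize_in_batches(entries, to_data, n_token_limit=None, batch_size=1_000_000):
    # Accumulate tokenized entries and yield them in chunks of at most
    # `batch_size`, so the full dataset never has to sit tokenized in memory.
    batch = []
    for i, entry in enumerate(tqdm.tqdm(entries, total=len(entries))):
        if entry["features"] is not None:
            batch.append(to_data(entry))
        if i > 0 and i % batch_size == 0:
            # Drop entries that failed tokenization or exceed the token limit.
            yield [
                b for b in batch
                if b["features"] is not None
                and (n_token_limit is None or len(b["features"]) <= n_token_limit)
            ]
            batch = []
    # Final, possibly partial batch.
    yield [
        b for b in batch
        if b["features"] is not None
        and (n_token_limit is None or len(b["features"]) <= n_token_limit)
    ]

# Toy usage with a character-level stand-in for the SMILES tokenizer.
entries = [{"features": "CCO", "labels": None}, {"features": None, "labels": None}]
to_data = lambda e: {"features": list(e["features"]), "labels": e["labels"]}
for batch in tokenize_in_batches(entries, to_data, n_token_limit=10, batch_size=2):
    print(f"batch of {len(batch)} entries")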
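
The split logic in setup_processed carves validation and test out in two passes: the first train_test_split removes both at once via test_size=test_batch_size + val_batch_size, and a second split with train_size=test_batch_size divides that remainder into test and val. A toy illustration with made-up sizes (sklearn's train_test_split accepts absolute counts as well as fractions):

from sklearn.model_selection import train_test_split

data = list(range(100))                  # stand-in for the untokenized entries
test_batch_size, val_batch_size = 10, 5  # illustrative values only

# First split: pull test + validation out of the training pool in one go.
train, rest = train_test_split(data, test_size=test_batch_size + val_batch_size)
# Second split: divide the remainder into test and validation.
test, val = train_test_split(rest, train_size=test_batch_size)

print(len(train), len(test), len(val))  # 85 10 5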
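
Since training data may now land in several train_{i}.pt shards rather than a single train.pt, consumers have to iterate over whichever files exist. A hedged sketch of one way to read them back; iter_train_shards and the directory path are hypothetical, not part of the chebai API:

import os
import torch

def iter_train_shards(processed_dir):
    # Hypothetical helper: yield each saved shard (a list of data dicts)
    # in index order, falling back to the single-file layout.
    single = os.path.join(processed_dir, "train.pt")
    if os.path.isfile(single):
        yield torch.load(single)
        return
    i = 0
    while os.path.isfile(os.path.join(processed_dir, f"train_{i}.pt")):
        yield torch.load(os.path.join(processed_dir, f"train_{i}.pt"))
        i += 1

# Usage: count entries across all shards (path is illustrative).
total = sum(len(shard) for shard in iter_train_shards("data/pubchem/processed"))
print(total, "training entries")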