Skip to content

Commit 4288689

Browse files
committed
reformatting
1 parent 6df484d commit 4288689

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,7 @@ class PubChemBatched(PubChem):
217217

218218
READER: Type[dr.ChemDataReader] = dr.ChemDataReader
219219

220-
def __init__(
221-
self, train_batch_size=10_000_000, *args, **kwargs
222-
):
220+
def __init__(self, train_batch_size=10_000_000, *args, **kwargs):
223221
super(PubChemBatched, self).__init__(*args, **kwargs)
224222
self.curr_epoch = 0
225223
self.train_batch_size = train_batch_size
@@ -245,7 +243,7 @@ def processed_file_names_dict(self) -> List[str]:
245243
List[str]: List of processed data file names.
246244
"""
247245
train_samples = (
248-
self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size
246+
self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size
249247
) # estimate size
250248
train_samples -= self.val_batch_size + self.test_batch_size
251249
train_batches = (
@@ -279,21 +277,15 @@ def _tokenize_batched(self, data):
279277
batch.append(self.reader.to_data(d))
280278
if i % 1_000_000 == 0 and i > 0:
281279
print(f"Saving batch {i // 1_000_000}")
282-
batch = [
283-
b
284-
for b in batch
285-
if b["features"] is not None
286-
]
280+
batch = [b for b in batch if b["features"] is not None]
287281
if self.n_token_limit is not None:
288-
batch = [b for b in batch if len(b["features"]) <= self.n_token_limit]
282+
batch = [
283+
b for b in batch if len(b["features"]) <= self.n_token_limit
284+
]
289285
yield batch
290286
batch = []
291287
print("Saving final batch")
292-
batch = [
293-
b
294-
for b in batch
295-
if b["features"] is not None
296-
]
288+
batch = [b for b in batch if b["features"] is not None]
297289
if self.n_token_limit is not None:
298290
batch = [b for b in batch if len(b["features"]) <= self.n_token_limit]
299291
yield batch
@@ -348,13 +340,18 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
348340
DataLoader: A DataLoader object for training data.
349341
"""
350342
return self.dataloader(
351-
"train" if "train" in self.processed_file_names_dict else f"train_{self.curr_epoch}",
343+
(
344+
"train"
345+
if "train" in self.processed_file_names_dict
346+
else f"train_{self.curr_epoch}"
347+
),
352348
shuffle=True,
353349
num_workers=self.num_workers,
354350
persistent_workers=True,
355351
**kwargs,
356352
)
357353

354+
358355
class PubChemDissimilar(PubChem):
359356
"""
360357
Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint)

0 commit comments

Comments (0)