Skip to content

Commit abffcae

Browse files
committed
fix k
1 parent db5434f commit abffcae

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,7 @@ def processed_file_names_dict(self) -> List[str]:
181181
Returns:
182182
List[str]: List of processed data file names.
183183
"""
184-
return {
185-
"train": "train.pt",
186-
"test": "test.pt",
187-
"validation": "validation.pt"
188-
}
184+
return {"train": "train.pt", "test": "test.pt", "validation": "validation.pt"}
189185

190186
def _set_processed_data_props(self):
191187
"""
@@ -221,7 +217,9 @@ class PubChemBatched(PubChem):
221217

222218
READER: Type[dr.ChemDataReader] = dr.ChemDataReader
223219

224-
def __init__(self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs):
220+
def __init__(
221+
self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs
222+
):
225223
super(PubChemBatched, self).__init__(*args, **kwargs)
226224
self.pc_train_batch_idx = pc_train_batch_idx
227225
self.train_batch_size = train_batch_size
@@ -254,7 +252,9 @@ def processed_file_names_dict(self) -> List[str]:
254252
{"train": "train.pt"}
255253
if train_samples <= self.train_batch_size
256254
else {
257-
f"train" if i == self.pc_train_batch_idx else f"train_{i}": f"train_{i}.pt"
255+
(
256+
"train" if i == self.pc_train_batch_idx else f"train_{i}"
257+
): f"train_{i}.pt"
258258
for i in range((train_samples // self.train_batch_size) + 1)
259259
}
260260
)
@@ -275,8 +275,16 @@ def setup_processed(self):
275275
)
276276
del data
277277
test, val = train_test_split(test, train_size=self.test_batch_size)
278-
torch.save(test, os.path.join(self.processed_dir, self.processed_file_names_dict["test"]))
279-
torch.save(val, os.path.join(self.processed_dir, self.processed_file_names_dict["validation"]))
278+
torch.save(
279+
test,
280+
os.path.join(self.processed_dir, self.processed_file_names_dict["test"]),
281+
)
282+
torch.save(
283+
val,
284+
os.path.join(
285+
self.processed_dir, self.processed_file_names_dict["validation"]
286+
),
287+
)
280288

281289
# batch training if necessary
282290
if len(train) > self.train_batch_size:

0 commit comments

Comments
 (0)