
Commit db5434f

update file name system for pubchem batched
1 parent bb73784 commit db5434f

File tree

1 file changed (+21, -12 lines)

chebai/preprocessing/datasets/pubchem.py

Lines changed: 21 additions & 12 deletions
@@ -176,12 +176,16 @@ def raw_file_names(self) -> List[str]:
         return ["smiles.txt"]
 
     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_file_names_dict(self) -> List[str]:
         """
         Returns:
             List[str]: List of processed data file names.
         """
-        return ["test.pt", "train.pt", "validation.pt"]
+        return {
+            "train": "train.pt",
+            "test": "test.pt",
+            "validation": "validation.pt"
+        }
 
     def _set_processed_data_props(self):
         """
@@ -215,9 +219,12 @@ def _perform_data_preparation(self, *args, **kwargs):
 class PubChemBatched(PubChem):
     """Store train data as batches of 10m, validation and test should each be 100k max"""
 
-    def __init__(self, *args, **kwargs):
+    READER: Type[dr.ChemDataReader] = dr.ChemDataReader
+
+    def __init__(self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs):
         super(PubChemBatched, self).__init__(*args, **kwargs)
-        self.train_batch_size = 10_000_000
+        self.pc_train_batch_idx = pc_train_batch_idx
+        self.train_batch_size = train_batch_size
         if self.k != self.FULL:
             self.val_batch_size = (
                 100_000
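
PubChemBatched now takes the active training batch (pc_train_batch_idx) and the batch size (train_batch_size) as constructor arguments instead of hard-coding the size. A hedged instantiation sketch, using the defaults shown above; any further arguments the PubChem base class requires are omitted here and may be needed in practice:

from chebai.preprocessing.datasets.pubchem import PubChemBatched

# Select the third 10M-sample training batch; it is the one exposed under the
# plain "train" key (see the processed_file_names_dict change further down).
dataset = PubChemBatched(pc_train_batch_idx=2, train_batch_size=10_000_000)
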
@@ -234,7 +241,7 @@ def __init__(self, *args, **kwargs):
         self.test_batch_size = 100_000
 
     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_file_names_dict(self) -> List[str]:
         """
         Returns:
             List[str]: List of processed data file names.
@@ -244,14 +251,16 @@ def processed_file_names(self) -> List[str]:
         ) # estimate size
         train_samples -= self.val_batch_size + self.test_batch_size
         train_batches = (
-            ["train.pt"]
+            {"train": "train.pt"}
             if train_samples <= self.train_batch_size
-            else [
-                f"train_{i}.pt"
+            else {
+                f"train" if i == self.pc_train_batch_idx else f"train_{i}": f"train_{i}.pt"
                 for i in range((train_samples // self.train_batch_size) + 1)
-            ]
+            }
         )
-        return train_batches + ["test.pt", "validation.pt"]
+        train_batches["test"] = "test.pt"
+        train_batches["validation"] = "validation.pt"
+        return train_batches
 
     def setup_processed(self):
         """
@@ -266,8 +275,8 @@ def setup_processed(self):
         )
         del data
         test, val = train_test_split(test, train_size=self.test_batch_size)
-        torch.save(test, os.path.join(self.processed_dir, "test.pt"))
-        torch.save(val, os.path.join(self.processed_dir, "validation.pt"))
+        torch.save(test, os.path.join(self.processed_dir, self.processed_file_names_dict["test"]))
+        torch.save(val, os.path.join(self.processed_dir, self.processed_file_names_dict["validation"]))
 
         # batch training if necessary
         if len(train) > self.train_batch_size:

0 commit comments