Skip to content

Commit a5070e8

Browse files
committed
if padding applied, create separate data.pt file
1 parent 9c56753 commit a5070e8

File tree

1 file changed

+38
-8
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+38
-8
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -349,41 +349,71 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
349349
if isinstance(p, AtomProperty)
350350
)
351351

352-
in_channels_str = f"in_channels: {n_node_properties}"
352+
in_channels_str = ""
353353
if self.zero_pad_node:
354354
n_node_properties += self.zero_pad_node
355-
in_channels_str += f"(with {self.zero_pad_node} padded zeros)"
355+
in_channels_str += f" (with {self.zero_pad_node} padded zeros)"
356356

357357
if self.random_pad_node:
358358
n_node_properties += self.random_pad_node
359-
in_channels_str += f"(with {self.random_pad_node} random padded values from {self.distribution} distribution)"
359+
in_channels_str += f" (with {self.random_pad_node} random padded values from {self.distribution} distribution)"
360+
361+
in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str
360362

361363
# -------------------------- Count total edge properties
362364
n_edge_properties = sum(
363365
p.encoder.get_encoding_length()
364366
for p in self.properties
365367
if isinstance(p, BondProperty)
366368
)
367-
edge_dim_str = f"edge_dim: {n_edge_properties}"
369+
edge_dim_str = ""
368370

369371
if self.zero_pad_edge:
370372
n_edge_properties += self.zero_pad_edge
371-
edge_dim_str += f"(with {self.zero_pad_edge} padded zeros)"
373+
edge_dim_str += f" (with {self.zero_pad_edge} padded zeros)"
372374

373375
if self.random_pad_edge:
374376
n_edge_properties += self.random_pad_edge
375-
edge_dim_str += f"(with {self.random_pad_edge} random padded values from {self.distribution} distribution)"
377+
edge_dim_str += f" (with {self.random_pad_edge} random padded values from {self.distribution} distribution)"
378+
379+
edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str
376380

377381
rank_zero_info(
378382
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
379383
f"Use following values for given parameters for model configuration: \n\t"
380-
f"{in_channels_str}, "
381-
f"{edge_dim_str}, "
384+
f"{in_channels_str} \n\t"
385+
f"{edge_dim_str} \n\t"
382386
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}"
383387
)
384388

385389
return base_df[base_data[0].keys()].to_dict("records")
386390

391+
@property
392+
def processed_file_names_dict(self) -> dict:
393+
"""
394+
Returns a dictionary for the processed and tokenized data files.
395+
396+
Returns:
397+
dict: A dictionary mapping dataset keys to their respective file names.
398+
For example, {"data": "data.pt"}.
399+
"""
400+
if self.n_token_limit is not None:
401+
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
402+
403+
data_pt_filename = "data"
404+
if self.zero_pad_node:
405+
data_pt_filename += f"_zpn{self.zero_pad_node}"
406+
if self.zero_pad_edge:
407+
data_pt_filename += f"_zpe{self.zero_pad_edge}"
408+
if self.random_pad_node:
409+
data_pt_filename += f"_rpn{self.random_pad_node}"
410+
if self.random_pad_edge:
411+
data_pt_filename += f"_rpe{self.random_pad_edge}"
412+
if self.random_pad_node or self.random_pad_edge:
413+
data_pt_filename += f"_D{self.distribution}"
414+
415+
return {"data": data_pt_filename + ".pt"}
416+
387417

388418
class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
389419
def __init__(self, properties=None, transform=None, **kwargs):

0 commit comments

Comments
 (0)