Skip to content

Commit ca3f305

Browse files
committed
zero padding for atom properties
1 parent dfdd810 commit ca3f305

File tree

1 file changed

+22
-1
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+22
-1
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ def _after_setup(self, **kwargs) -> None:
178178

179179

180180
class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
181+
def __init__(
    self, properties=None, transform=None, zero_pad_atom: int = None, **kwargs
):
    """
    Initialize the mix-in and configure optional zero-padding of atom features.

    Args:
        properties: Forwarded unchanged to the parent initializer.
        transform: Forwarded unchanged to the parent initializer.
        zero_pad_atom: Number of extra all-zero columns to append to the
            atom-level feature matrix. ``None`` (the default) disables
            padding. Any other value is converted with ``int()``.
        **kwargs: Additional keyword arguments forwarded to the parent
            initializer.

    Raises:
        ValueError: If ``zero_pad_atom`` is negative.
    """
    super().__init__(properties, transform, **kwargs)
    self.zero_pad_atom = int(zero_pad_atom) if zero_pad_atom is not None else None
    # Fail fast on a negative width: it would otherwise only surface later,
    # when the padding tensor is built while merging properties.
    if self.zero_pad_atom is not None and self.zero_pad_atom < 0:
        raise ValueError(
            f"zero_pad_atom must be a non-negative integer, got {self.zero_pad_atom}"
        )
    # A width of 0 is valid but adds nothing, so it is not announced.
    if self.zero_pad_atom:
        print(
            f"[Info] Atom-level features will be zero-padded with "
            f"{self.zero_pad_atom} additional dimensions."
        )
191+
181192
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
182193
"""
183194
Merge encoded molecular properties into the GeomData object.
@@ -219,6 +230,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
219230
else:
220231
raise TypeError(f"Unsupported property type: {type(property).__name__}")
221232

233+
if self.zero_pad_atom is not None:
234+
x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_atom))], dim=1)
235+
222236
return GeomData(
223237
x=x,
224238
edge_index=geom_data.edge_index,
@@ -265,10 +279,17 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
265279
prop_lengths = [
266280
(prop.name, prop.encoder.get_encoding_length()) for prop in self.properties
267281
]
282+
n_node_properties = sum(
283+
p.encoder.get_encoding_length()
284+
for p in self.properties
285+
if isinstance(p, AtomProperty)
286+
)
287+
if self.zero_pad_atom:
288+
n_node_properties += self.zero_pad_atom
268289
rank_zero_info(
269290
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
270291
f"Use following values for given parameters for model configuration: \n\t"
271-
f"in_channels: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, "
292+
f"in_channels: {n_node_properties} (with {self.zero_pad_atom} padded zeros) , "
272293
f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, "
273294
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}"
274295
)

0 commit comments

Comments
 (0)