Skip to content

Commit 0b5f650

Browse files
committed
add padding with zeros or randomness for node and edges
1 parent bc3981b commit 0b5f650

File tree

2 files changed

+90
-12
lines changed

2 files changed

+90
-12
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,50 @@ def _after_setup(self, **kwargs) -> None:
186186

187187
class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
188188
def __init__(
189-
self, properties=None, transform=None, zero_pad_atom: int = None, **kwargs
189+
self,
190+
properties=None,
191+
transform=None,
192+
zero_pad_node: int = None,
193+
zero_pad_edge: int = None,
194+
random_pad_node: int = None,
195+
random_pad_edge: int = None,
196+
distribution: str = "normal",
197+
**kwargs,
190198
):
191199
super().__init__(properties, transform, **kwargs)
192-
self.zero_pad_atom = int(zero_pad_atom) if zero_pad_atom is not None else None
193-
if self.zero_pad_atom:
200+
self.zero_pad_node = int(zero_pad_node) if zero_pad_node else None
201+
if self.zero_pad_node:
202+
print(
203+
f"[Info] Node-level features will be zero-padded with "
204+
f"{self.zero_pad_node} additional dimensions."
205+
)
206+
207+
self.zero_pad_edge = int(zero_pad_edge) if zero_pad_edge else None
208+
if self.zero_pad_edge:
194209
print(
195-
f"[Info] Atom-level features will be zero-padded with "
196-
f"{self.zero_pad_atom} additional dimensions."
210+
f"[Info] Edge-level features will be zero-padded with "
211+
f"{self.zero_pad_edge} additional dimensions."
197212
)
198213

214+
self.random_pad_edge = int(random_pad_edge) if random_pad_edge else None
215+
self.random_pad_node = int(random_pad_node) if random_pad_node else None
216+
if self.random_pad_node or self.random_pad_edge:
217+
assert (
218+
distribution is not None
219+
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
220+
), "When using random padding, a valid distribution must be specified."
221+
self.distribution = distribution
222+
if self.random_pad_node:
223+
print(
224+
f"[Info] Node-level features will be padded with "
225+
f"{self.random_pad_node} additional dimensions initialized from {self.distribution} distribution."
226+
)
227+
if self.random_pad_edge:
228+
print(
229+
f"[Info] Edge-level features will be padded with "
230+
f"{self.random_pad_edge} additional dimensions initialized from {self.distribution} distribution."
231+
)
232+
199233
if self.properties:
200234
print(
201235
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
@@ -242,8 +276,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
242276
else:
243277
raise TypeError(f"Unsupported property type: {type(property).__name__}")
244278

245-
if self.zero_pad_atom is not None:
246-
x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_atom))], dim=1)
279+
if self.zero_pad_node:
280+
x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_node))], dim=1)
281+
282+
if self.zero_pad_edge:
283+
edge_attr = torch.cat(
284+
[edge_attr, torch.zeros((edge_attr.shape[0], self.zero_pad_edge))],
285+
dim=1,
286+
)
287+
288+
if self.random_pad_node:
289+
random_pad = torch.empty((x.shape[0], self.random_pad_node))
290+
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
291+
x = torch.cat([x, random_pad], dim=1)
292+
293+
if self.random_pad_edge:
294+
random_pad = torch.empty((edge_attr.shape[0], self.random_pad_edge))
295+
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
296+
edge_attr = torch.cat([edge_attr, random_pad], dim=1)
247297

248298
return GeomData(
249299
x=x,
@@ -291,18 +341,44 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
291341
prop_lengths = [
292342
(prop.name, prop.encoder.get_encoding_length()) for prop in self.properties
293343
]
344+
345+
# -------------------------- Count total node properties
294346
n_node_properties = sum(
295347
p.encoder.get_encoding_length()
296348
for p in self.properties
297349
if isinstance(p, AtomProperty)
298350
)
299-
if self.zero_pad_atom:
300-
n_node_properties += self.zero_pad_atom
351+
352+
in_channels_str = f"in_channels: {n_node_properties}"
353+
if self.zero_pad_node:
354+
n_node_properties += self.zero_pad_node
355+
in_channels_str += f"(with {self.zero_pad_node} padded zeros)"
356+
357+
if self.random_pad_node:
358+
n_node_properties += self.random_pad_node
359+
in_channels_str += f"(with {self.random_pad_node} random padded values from {self.distribution} distribution)"
360+
361+
# -------------------------- Count total edge properties
362+
n_edge_properties = sum(
363+
p.encoder.get_encoding_length()
364+
for p in self.properties
365+
if isinstance(p, BondProperty)
366+
)
367+
edge_dim_str = f"edge_dim: {n_edge_properties}"
368+
369+
if self.zero_pad_edge:
370+
n_edge_properties += self.zero_pad_edge
371+
edge_dim_str += f"(with {self.zero_pad_edge} padded zeros)"
372+
373+
if self.random_pad_edge:
374+
n_edge_properties += self.random_pad_edge
375+
edge_dim_str += f"(with {self.random_pad_edge} random padded values from {self.distribution} distribution)"
376+
301377
rank_zero_info(
302378
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
303379
f"Use following values for given parameters for model configuration: \n\t"
304-
f"in_channels: {n_node_properties} (with {self.zero_pad_atom} padded zeros) , "
305-
f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, "
380+
f"{in_channels_str}, "
381+
f"{edge_dim_str}, "
306382
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}"
307383
)
308384

chebai_graph/preprocessing/reader/static_gni.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414

1515
class RandomFeatureInitializationReader(GraphPropertyReader):
16+
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform"]
17+
1618
def __init__(
1719
self,
1820
num_node_properties: int,
@@ -26,7 +28,7 @@ def __init__(
2628
self.num_node_properties = num_node_properties
2729
self.num_bond_properties = num_bond_properties
2830
self.num_molecule_properties = num_molecule_properties
29-
assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
31+
assert distribution in self.DISTRIBUTIONS
3032
self.distribution = distribution
3133

3234
def name(self) -> str:

0 commit comments

Comments
 (0)