Skip to content

Commit 1ba24e0

Browse files
committed
include zero padding random init reader static method
1 parent ba9de36 commit 1ba24e0

File tree

3 files changed

+50
-78
lines changed

3 files changed

+50
-78
lines changed

chebai_graph/models/dynamic_gni.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,35 +41,35 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
4141

4242
if not self.complete_randomness:
4343
assert (
44-
"random_pad_node" in config or "random_pad_edge" in config
45-
), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False"
46-
self.random_pad_node = (
47-
int(config["random_pad_node"])
48-
if config.get("random_pad_node") is not None
44+
"pad_node_features" in config or "pad_edge_features" in config
45+
), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
46+
self.pad_node_features = (
47+
int(config["pad_node_features"])
48+
if config.get("pad_node_features") is not None
4949
else None
5050
)
51-
if self.random_pad_node is not None:
51+
if self.pad_node_features is not None:
5252
print(
53-
f"[Info] Node features will be padded with {self.random_pad_node} "
53+
f"[Info] Node features will be padded with {self.pad_node_features} "
5454
f"new set of random features from distribution {self.distribution} "
5555
f"in each forward pass."
5656
)
5757

58-
self.random_pad_edge = (
59-
int(config["random_pad_edge"])
60-
if config.get("random_pad_edge") is not None
58+
self.pad_edge_features = (
59+
int(config["pad_edge_features"])
60+
if config.get("pad_edge_features") is not None
6161
else None
6262
)
63-
if self.random_pad_edge is not None:
63+
if self.pad_edge_features is not None:
6464
print(
65-
f"[Info] Edge features will be padded with {self.random_pad_edge} "
65+
f"[Info] Edge features will be padded with {self.pad_edge_features} "
6666
f"new set of random features from distribution {self.distribution} "
6767
f"in each forward pass."
6868
)
6969

7070
assert (
71-
self.random_pad_node > 0 or self.random_pad_edge > 0
72-
), "'random_pad_node' or 'random_pad_edge' must be positive integers"
71+
self.pad_node_features > 0 or self.pad_edge_features > 0
72+
), "'pad_node_features' or 'pad_edge_features' must be positive integers"
7373

7474
self.resgated: BasicGNN = ResGatedModel(
7575
in_channels=self.in_channels,
@@ -111,21 +111,21 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
111111
new_edge_attr, self.distribution
112112
)
113113
else:
114-
if self.random_pad_node is not None:
114+
if self.pad_node_features is not None:
115115
pad_node = torch.empty(
116116
graph_data.x.shape[0],
117-
self.random_pad_node,
117+
self.pad_node_features,
118118
device=self.device,
119119
)
120120
RandomFeatureInitializationReader.random_gni(
121121
pad_node, self.distribution
122122
)
123123
new_x = torch.cat((graph_data.x, pad_node), dim=1)
124124

125-
if self.random_pad_edge is not None:
125+
if self.pad_edge_features is not None:
126126
pad_edge = torch.empty(
127127
graph_data.edge_attr.shape[0],
128-
self.random_pad_edge,
128+
self.pad_edge_features,
129129
device=self.device,
130130
)
131131
RandomFeatureInitializationReader.random_gni(

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 29 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -189,45 +189,29 @@ def __init__(
189189
self,
190190
properties=None,
191191
transform=None,
192-
zero_pad_node: int = None,
193-
zero_pad_edge: int = None,
194-
random_pad_node: int = None,
195-
random_pad_edge: int = None,
192+
pad_node_features: int = None,
193+
pad_edge_features: int = None,
196194
distribution: str = "normal",
197195
**kwargs,
198196
):
199197
super().__init__(properties, transform, **kwargs)
200-
self.zero_pad_node = int(zero_pad_node) if zero_pad_node else None
201-
if self.zero_pad_node:
202-
print(
203-
f"[Info] Node-level features will be zero-padded with "
204-
f"{self.zero_pad_node} additional dimensions."
205-
)
206-
207-
self.zero_pad_edge = int(zero_pad_edge) if zero_pad_edge else None
208-
if self.zero_pad_edge:
209-
print(
210-
f"[Info] Edge-level features will be zero-padded with "
211-
f"{self.zero_pad_edge} additional dimensions."
212-
)
213-
214-
self.random_pad_edge = int(random_pad_edge) if random_pad_edge else None
215-
self.random_pad_node = int(random_pad_node) if random_pad_node else None
216-
if self.random_pad_node or self.random_pad_edge:
198+
self.pad_edge_features = int(pad_edge_features) if pad_edge_features else None
199+
self.pad_node_features = int(pad_node_features) if pad_node_features else None
200+
if self.pad_node_features or self.pad_edge_features:
217201
assert (
218202
distribution is not None
219203
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
220-
), "When using random padding, a valid distribution must be specified."
204+
), "When using padding for features, a valid distribution must be specified."
221205
self.distribution = distribution
222-
if self.random_pad_node:
206+
if self.pad_node_features:
223207
print(
224-
f"[Info] Node-level features will be padded with "
225-
f"{self.random_pad_node} additional dimensions initialized from {self.distribution} distribution."
208+
f"[Info] Node-level features will be padded with random"
209+
f"{self.pad_node_features} values from {self.distribution} distribution."
226210
)
227-
if self.random_pad_edge:
211+
if self.pad_edge_features:
228212
print(
229-
f"[Info] Edge-level features will be padded with "
230-
f"{self.random_pad_edge} additional dimensions initialized from {self.distribution} distribution."
213+
f"[Info] Edge-level features will be padded with random"
214+
f"{self.pad_edge_features} values from {self.distribution} distribution."
231215
)
232216

233217
if self.properties:
@@ -276,24 +260,19 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
276260
else:
277261
raise TypeError(f"Unsupported property type: {type(property).__name__}")
278262

279-
if self.zero_pad_node:
280-
x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_node))], dim=1)
281-
282-
if self.zero_pad_edge:
283-
edge_attr = torch.cat(
284-
[edge_attr, torch.zeros((edge_attr.shape[0], self.zero_pad_edge))],
285-
dim=1,
263+
if self.pad_node_features:
264+
padding_values = torch.empty((x.shape[0], self.pad_node_features))
265+
RandomFeatureInitializationReader.random_gni(
266+
padding_values, self.distribution
286267
)
268+
x = torch.cat([x, padding_values], dim=1)
287269

288-
if self.random_pad_node:
289-
random_pad = torch.empty((x.shape[0], self.random_pad_node))
290-
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
291-
x = torch.cat([x, random_pad], dim=1)
292-
293-
if self.random_pad_edge:
294-
random_pad = torch.empty((edge_attr.shape[0], self.random_pad_edge))
295-
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
296-
edge_attr = torch.cat([edge_attr, random_pad], dim=1)
270+
if self.pad_edge_features:
271+
padding_values = torch.empty((edge_attr.shape[0], self.pad_edge_features))
272+
RandomFeatureInitializationReader.random_gni(
273+
padding_values, self.distribution
274+
)
275+
edge_attr = torch.cat([edge_attr, padding_values], dim=1)
297276

298277
return GeomData(
299278
x=x,
@@ -350,13 +329,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
350329
)
351330

352331
in_channels_str = ""
353-
if self.zero_pad_node:
354-
n_node_properties += self.zero_pad_node
355-
in_channels_str += f" (with {self.zero_pad_node} padded zeros)"
356-
357-
if self.random_pad_node:
358-
n_node_properties += self.random_pad_node
359-
in_channels_str += f" (with {self.random_pad_node} random padded values from {self.distribution} distribution)"
332+
if self.pad_node_features:
333+
n_node_properties += self.pad_node_features
334+
in_channels_str += f" (with {self.pad_node_features} padded random values from {self.distribution} distribution)"
360335

361336
in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str
362337

@@ -367,14 +342,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
367342
if isinstance(p, BondProperty)
368343
)
369344
edge_dim_str = ""
370-
371-
if self.zero_pad_edge:
372-
n_edge_properties += self.zero_pad_edge
373-
edge_dim_str += f" (with {self.zero_pad_edge} padded zeros)"
374-
375-
if self.random_pad_edge:
376-
n_edge_properties += self.random_pad_edge
377-
edge_dim_str += f" (with {self.random_pad_edge} random padded values from {self.distribution} distribution)"
345+
if self.pad_edge_features:
346+
n_edge_properties += self.pad_edge_features
347+
edge_dim_str += f" (with {self.pad_edge_features} padded random values from {self.distribution} distribution)"
378348

379349
edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str
380350

chebai_graph/preprocessing/reader/static_gni.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class RandomFeatureInitializationReader(GraphPropertyReader):
16-
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform"]
16+
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"]
1717

1818
def __init__(
1919
self,
@@ -74,5 +74,7 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None:
7474
torch.nn.init.xavier_normal_(tensor)
7575
elif distribution == "xavier_uniform":
7676
torch.nn.init.xavier_uniform_(tensor)
77+
elif distribution == "zeros":
78+
torch.nn.init.zeros_(tensor)
7779
else:
7880
raise ValueError("Unknown distribution type")

0 commit comments

Comments
 (0)