Skip to content

Commit 9c56753

Browse files
committed
for dynamic randomness set option for padding existing features with random features
1 parent 0b5f650 commit 9c56753

File tree

1 file changed

+69
-17
lines changed

1 file changed

+69
-17
lines changed

chebai_graph/models/dynamic_gni.py

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,34 @@ class ResGatedDynamicGNI(GraphModelBase):
2525
def __init__(self, config: dict[str, Any], **kwargs: Any):
2626
super().__init__(config=config, **kwargs)
2727
self.activation = ELU() # Instantiate ELU once for reuse.
28+
2829
distribution = config.get("distribution", "normal")
29-
assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
30+
assert distribution in RandomFeatureInitializationReader.DISTRIBUTIONS, (
31+
f"Unsupported distribution: {distribution}. "
32+
f"Choose from {RandomFeatureInitializationReader.DISTRIBUTIONS}."
33+
)
3034
self.distribution = distribution
3135

36+
self.complete_randomness = config.get("complete_randomness", True)
37+
38+
if not self.complete_randomness:
39+
assert (
40+
"random_pad_node" in config or "random_pad_edge" in config
41+
), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False"
42+
self.random_pad_node = (
43+
int(config["random_pad_node"])
44+
if config.get("random_pad_node") is not None
45+
else None
46+
)
47+
self.random_pad_edge = (
48+
int(config["random_pad_edge"])
49+
if config.get("random_pad_edge") is not None
50+
else None
51+
)
52+
assert (
53+
self.random_pad_node > 0 or self.random_pad_edge > 0
54+
), "'random_pad_node' or 'random_pad_edge' must be positive integers"
55+
3256
self.resgated: BasicGNN = ResGatedModel(
3357
in_channels=self.in_channels,
3458
hidden_channels=self.hidden_channels,
@@ -52,24 +76,52 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
5276
graph_data = batch["features"][0]
5377
assert isinstance(graph_data, GraphData), "Expected GraphData instance"
5478

55-
random_x = torch.empty(
56-
graph_data.x.shape[0], graph_data.x.shape[1], device=self.device
57-
)
58-
RandomFeatureInitializationReader.random_gni(random_x, self.distribution)
59-
60-
random_edge_attr = torch.empty(
61-
graph_data.edge_attr.shape[0],
62-
graph_data.edge_attr.shape[1],
63-
device=self.device,
64-
)
65-
RandomFeatureInitializationReader.random_gni(
66-
random_edge_attr, self.distribution
67-
)
68-
79+
new_x = None
80+
new_edge_attr = None
81+
if self.complete_randomness:
82+
new_x = torch.empty(
83+
graph_data.x.shape[0], graph_data.x.shape[1], device=self.device
84+
)
85+
RandomFeatureInitializationReader.random_gni(new_x, self.distribution)
86+
87+
new_edge_attr = torch.empty(
88+
graph_data.edge_attr.shape[0],
89+
graph_data.edge_attr.shape[1],
90+
device=self.device,
91+
)
92+
RandomFeatureInitializationReader.random_gni(
93+
new_edge_attr, self.distribution
94+
)
95+
else:
96+
if self.random_pad_node is not None:
97+
pad_node = torch.empty(
98+
graph_data.x.shape[0],
99+
self.random_pad_node,
100+
device=self.device,
101+
)
102+
RandomFeatureInitializationReader.random_gni(
103+
pad_node, self.distribution
104+
)
105+
new_x = torch.cat((graph_data.x, pad_node), dim=1)
106+
107+
if self.random_pad_edge is not None:
108+
pad_edge = torch.empty(
109+
graph_data.edge_attr.shape[0],
110+
self.random_pad_edge,
111+
device=self.device,
112+
)
113+
RandomFeatureInitializationReader.random_gni(
114+
pad_edge, self.distribution
115+
)
116+
new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1)
117+
118+
assert (
119+
new_x is not None and new_edge_attr is not None
120+
), "Feature initialization failed"
69121
out = self.resgated(
70-
x=random_x.float(),
122+
x=new_x.float(),
71123
edge_index=graph_data.edge_index.long(),
72-
edge_attr=random_edge_attr.float(),
124+
edge_attr=new_edge_attr.float(),
73125
)
74126

75127
return self.activation(out)

0 commit comments

Comments
 (0)