Skip to content

Commit 2757ed3

Browse files
committed
fix gat v2 share weights issue
1 parent 7626dad commit 2757ed3

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

chebai_graph/models/gat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def __init__(self, config: dict, **kwargs):
2828
super().__init__(config=config, **kwargs)
2929
self.heads = int(config["heads"])
3030
self.v2 = bool(config["v2"])
31-
self.share_weights = bool(config.get("share_weights", False))
31+
local_kwargs = {}
32+
if self.v2:
33+
local_kwargs["share_weights"] = bool(config.get("share_weights", False))
3234
self.activation = ELU() # Instantiate ELU once for reuse.
3335
self.gat = GAT(
3436
in_channels=self.in_channels,
@@ -40,7 +42,7 @@ def __init__(self, config: dict, **kwargs):
4042
heads=self.heads,
4143
v2=self.v2,
4244
act=self.activation,
43-
share_weights=self.share_weights,
45+
**local_kwargs,
4446
)
4547

4648
def forward(self, batch: dict) -> torch.Tensor:

0 commit comments

Comments
 (0)