Skip to content

Commit 0bd5956

Browse files
committed
gatv2-constrainted
1 parent 29c3ea3 commit 0bd5956

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

chebai_graph/models/gat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(self, config: dict, **kwargs):
2828
super().__init__(config=config, **kwargs)
2929
self.heads = int(config["heads"])
3030
self.v2 = bool(config["v2"])
31+
self.share_weights = bool(config.get("share_weights", False))
3132
self.activation = ELU() # Instantiate ELU once for reuse.
3233
self.gat = GAT(
3334
in_channels=self.in_channels,
@@ -39,6 +40,7 @@ def __init__(self, config: dict, **kwargs):
3940
heads=self.heads,
4041
v2=self.v2,
4142
act=self.activation,
43+
share_weights=self.share_weights
4244
)
4345

4446
def forward(self, batch: dict) -> torch.Tensor:

0 commit comments

Comments
 (0)