Skip to content

Commit e09e23c

Browse files
committed
gat aug node pool class
1 parent fb27521 commit e09e23c

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

chebai_graph/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .augmented import (
2+
GATAugNodePoolGraphPred,
23
ResGatedAugNodePoolGraphPred,
34
ResGatedAugOnlyPoolGraphPred,
45
ResGatedFGNodeNoGraphNodeGraphPred,
@@ -24,4 +25,5 @@
2425
"ResGatedAugOnlyPoolGraphPred",
2526
"ResGatedGraphNodeOnlyPoolGraphPred",
2627
"ResGatedFGOnlyPoolGraphPred",
28+
"GATAugNodePoolGraphPred",
2729
]

chebai_graph/models/augmented.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
GraphNodeOnlyPoolingNet,
1010
GraphNodePoolingNet,
1111
)
12+
from .gat import GATGraphPred
1213
from .resgated import ResGatedGraphPred
1314

1415

@@ -22,6 +23,16 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
2223
...
2324

2425

26+
class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred):
27+
"""
28+
Combines:
29+
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes.
30+
- GATGraphPred: Graph attention network for final graph prediction.
31+
"""
32+
33+
...
34+
35+
2536
class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedGraphPred):
2637
"""
2738
Combines:

0 commit comments

Comments
 (0)