Skip to content

Commit 0653b4f

Browse files
committed
GATv2 amg pool
1 parent 2994190 commit 0653b4f

File tree

4 files changed

+12
-98
lines changed

4 files changed

+12
-98
lines changed

chebai_graph/models/__init__.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,17 @@
11
from .augmented import (
22
GATAugNodePoolGraphPred,
3+
GATGraphNodeFGNodePoolGraphPred,
34
ResGatedAugNodePoolGraphPred,
4-
ResGatedAugOnlyPoolGraphPred,
5-
ResGatedFGNodeNoGraphNodeGraphPred,
6-
ResGatedFGNodePoolGraphPred,
7-
ResGatedFGOnlyPoolGraphPred,
85
ResGatedGraphNodeFGNodePoolGraphPred,
9-
ResGatedGraphNodeNoFGNodeGraphPred,
10-
ResGatedGraphNodeOnlyPoolGraphPred,
11-
ResGatedGraphNodePoolGraphPred,
126
)
137
from .gat import GATGraphPred
148
from .resgated import ResGatedGraphPred
159

1610
__all__ = [
17-
"GATGraphPred",
1811
"ResGatedGraphPred",
19-
"ResGatedFGNodeNoGraphNodeGraphPred",
2012
"ResGatedAugNodePoolGraphPred",
2113
"ResGatedGraphNodeFGNodePoolGraphPred",
22-
"ResGatedGraphNodePoolGraphPred",
23-
"ResGatedGraphNodeNoFGNodeGraphPred",
24-
"ResGatedFGNodePoolGraphPred",
25-
"ResGatedAugOnlyPoolGraphPred",
26-
"ResGatedGraphNodeOnlyPoolGraphPred",
27-
"ResGatedFGOnlyPoolGraphPred",
14+
"GATGraphPred",
2815
"GATAugNodePoolGraphPred",
16+
"GATGraphNodeFGNodePoolGraphPred",
2917
]

chebai_graph/models/augmented.py

Lines changed: 7 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
1-
from .base import (
2-
AugmentedNodePoolingNet,
3-
AugmentedOnlyPoolingNet,
4-
FGNodePoolingNet,
5-
FGNodePoolingNoGraphNodeNet,
6-
FGOnlyPoolingNet,
7-
GraphNodeFGNodePoolingNet,
8-
GraphNodeNoFGNodePoolingNet,
9-
GraphNodeOnlyPoolingNet,
10-
GraphNodePoolingNet,
11-
)
1+
from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet
122
from .gat import GATGraphPred
133
from .resgated import ResGatedGraphPred
144

155

166
class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
177
"""
188
Combines:
19-
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes.
9+
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes).
2010
- ResGatedGraphPred: Residual gated network for final graph prediction.
2111
"""
2212

@@ -26,94 +16,30 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
2616
class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred):
2717
"""
2818
Combines:
29-
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes.
19+
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes).
3020
- GATGraphPred: Graph attention network for final graph prediction.
3121
"""
3222

3323
...
3424

3525

36-
class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedGraphPred):
37-
"""
38-
Combines:
39-
- GraphNodePoolingNet: Pools atom and graph node embeddings with molecule attributes.
40-
- ResGatedGraphPred: Residual gated network for final graph prediction.
41-
"""
42-
43-
...
44-
45-
46-
class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedGraphPred):
47-
"""
48-
Combines:
49-
- FGNodePoolingNet: Pools functional group nodes and other nodes with molecule attributes.
50-
- ResGatedGraphPred: Residual gated network for final graph prediction.
51-
"""
52-
53-
...
54-
55-
5626
class ResGatedGraphNodeFGNodePoolGraphPred(
5727
GraphNodeFGNodePoolingNet, ResGatedGraphPred
5828
):
5929
"""
6030
Combines:
61-
- GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes with molecule attributes.
62-
- ResGatedGraphPred: Residual gated network for final graph prediction.
63-
"""
64-
65-
...
66-
67-
68-
class ResGatedGraphNodeNoFGNodeGraphPred(
69-
GraphNodeNoFGNodePoolingNet, ResGatedGraphPred
70-
):
71-
"""
72-
Combines:
73-
- GraphNodeNoFGNodePoolingNet: Pools atom and graph nodes, excluding functional groups.
74-
- ResGatedGraphPred: Residual gated network for final graph prediction.
75-
"""
76-
77-
...
78-
79-
80-
class ResGatedFGNodeNoGraphNodeGraphPred(
81-
FGNodePoolingNoGraphNodeNet, ResGatedGraphPred
82-
):
83-
"""
84-
Combines:
85-
- FGNodePoolingNoGraphNodeNet: Pools atom and functional group nodes, excluding graph nodes.
86-
- ResGatedGraphPred: Residual gated network for final graph prediction.
87-
"""
88-
89-
...
90-
91-
92-
class ResGatedAugOnlyPoolGraphPred(AugmentedOnlyPoolingNet, ResGatedGraphPred):
93-
"""
94-
Combines:
95-
- AugmentedOnlyPoolingNet: Pools only augmented nodes with molecule attributes.
31+
- GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes).
9632
- ResGatedGraphPred: Residual gated network for final graph prediction.
9733
"""
9834

9935
...
10036

10137

102-
class ResGatedGraphNodeOnlyPoolGraphPred(GraphNodeOnlyPoolingNet, ResGatedGraphPred):
38+
class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred):
10339
"""
10440
Combines:
105-
- GraphNodeOnlyPoolingNet: Pools only graph nodes with molecule attributes.
106-
- ResGatedGraphPred: Residual gated network for final graph prediction.
107-
"""
108-
109-
...
110-
111-
112-
class ResGatedFGOnlyPoolGraphPred(FGOnlyPoolingNet, ResGatedGraphPred):
113-
"""
114-
Combines:
115-
- FGOnlyPoolingNet: Pools only functional group nodes with molecule attributes.
116-
- ResGatedGraphPred: Residual gated network for final graph prediction.
41+
- GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes).
42+
- GATGraphPred: Graph attention network for final graph prediction.
11743
"""
11844

11945
...

chebai_graph/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _get_lin_seq_input_dim(
200200
201201
Includes:
202202
- Atom embeddings
203-
- Molecular attributes
203+
- Molecular attributes (if any)
204204
- Augmented node embeddings
205205
206206
Args:

chebai_graph/models/gat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, config: dict, **kwargs):
4040
heads=self.heads,
4141
v2=self.v2,
4242
act=self.activation,
43-
share_weights=self.share_weights
43+
share_weights=self.share_weights,
4444
)
4545

4646
def forward(self, batch: dict) -> torch.Tensor:

0 commit comments

Comments
 (0)