Skip to content

Commit bed5768

Browse files
authored
Merge pull request #2 from ChEB-AI/thesis_augmented_gnn
Thesis: Integration Chemical Knowledge into GNN
2 parents f1ce5b8 + 9c6d915 commit bed5768

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+8447
-432
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,8 @@ cython_debug/
169169
electra_pretrained.ckpt
170170
.isort.cfg
171171
/.vscode
172+
173+
*.err
174+
*.out
175+
*.sh
176+
*.ckpt

chebai_graph/models/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .augmented import (
2+
GATAugNodePoolGraphPred,
3+
GATGraphNodeFGNodePoolGraphPred,
4+
ResGatedAugNodePoolGraphPred,
5+
ResGatedGraphNodeFGNodePoolGraphPred,
6+
)
7+
from .dynamic_gni import ResGatedDynamicGNIGraphPred
8+
from .gat import GATGraphPred
9+
from .resgated import ResGatedGraphPred
10+
11+
__all__ = [
12+
"ResGatedGraphPred",
13+
"ResGatedAugNodePoolGraphPred",
14+
"ResGatedGraphNodeFGNodePoolGraphPred",
15+
"GATGraphPred",
16+
"GATAugNodePoolGraphPred",
17+
"GATGraphNodeFGNodePoolGraphPred",
18+
"ResGatedDynamicGNIGraphPred",
19+
]

chebai_graph/models/augmented.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet
2+
from .gat import GATGraphPred
3+
from .resgated import ResGatedGraphPred
4+
5+
6+
class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
7+
"""
8+
Combines:
9+
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes).
10+
- ResGatedGraphPred: Residual gated network for final graph prediction.
11+
"""
12+
13+
...
14+
15+
16+
class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred):
17+
"""
18+
Combines:
19+
- AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes).
20+
- GATGraphPred: Graph attention network for final graph prediction.
21+
"""
22+
23+
...
24+
25+
26+
class ResGatedGraphNodeFGNodePoolGraphPred(
27+
GraphNodeFGNodePoolingNet, ResGatedGraphPred
28+
):
29+
"""
30+
Combines:
31+
- GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes).
32+
- ResGatedGraphPred: Residual gated network for final graph prediction.
33+
"""
34+
35+
...
36+
37+
38+
class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred):
39+
"""
40+
Combines:
41+
- GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes).
42+
- GATGraphPred: Graph attention network for final graph prediction.
43+
"""
44+
45+
...

0 commit comments

Comments
 (0)