Skip to content

Commit 9338ed7

Browse files
use GNNlib in GNN.jl (#464)
* use GNNlib in GNN.jl * cleanup * ported all graph convs * workflow * fix * fix gcn_con * fix gcn_con * add comments
1 parent a9700f9 commit 9338ed7

File tree

12 files changed

+265
-1212
lines changed

12 files changed

+265
-1212
lines changed

.github/workflows/test_GraphNeuralNetworks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
# dev mono repo versions
3838
pkg"registry up"
3939
Pkg.update()
40-
pkg"dev ./GNNGraphs ."
40+
pkg"dev ./GNNGraphs ./GNNlib ."
4141
Pkg.test("GraphNeuralNetworks"; coverage=true)
4242
- uses: julia-actions/julia-processcoverage@v1
4343
with:

GNNlib/src/GNNlib.jl

Lines changed: 51 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,70 +12,65 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
1212
check_num_nodes, check_num_edges,
1313
EType, NType # for heteroconvs
1414

15-
export
16-
# utils
17-
reduce_nodes,
18-
reduce_edges,
19-
softmax_nodes,
20-
softmax_edges,
21-
broadcast_nodes,
22-
broadcast_edges,
23-
softmax_edge_neighbors,
24-
# msgpass
25-
apply_edges,
26-
aggregate_neighbors,
27-
propagate,
28-
copy_xj,
29-
copy_xi,
30-
xi_dot_xj,
31-
xi_sub_xj,
32-
xj_sub_xi,
33-
e_mul_xj,
34-
w_mul_xj
15+
include("utils.jl")
16+
export reduce_nodes,
17+
reduce_edges,
18+
softmax_nodes,
19+
softmax_edges,
20+
broadcast_nodes,
21+
broadcast_edges,
22+
softmax_edge_neighbors
3523

24+
include("msgpass.jl")
25+
export apply_edges,
26+
aggregate_neighbors,
27+
propagate,
28+
copy_xj,
29+
copy_xi,
30+
xi_dot_xj,
31+
xi_sub_xj,
32+
xj_sub_xi,
33+
e_mul_xj,
34+
w_mul_xj
35+
3636
## The following methods are defined but not exported
3737

38-
# # layers/basic
39-
# dot_decoder,
40-
41-
# # layers/conv
42-
# agnn_conv,
43-
# cg_conv,
44-
# cheb_conv,
45-
# edge_conv,
46-
# egnn_conv,
47-
# gat_conv,
48-
# gatv2_conv,
49-
# gated_graph_conv,
50-
# gcn_conv,
51-
# gin_conv,
52-
# gmm_conv,
53-
# graph_conv,
54-
# megnet_conv,
55-
# nn_conv,
56-
# res_gated_graph_conv,
57-
# sage_conv,
58-
# sg_conv,
59-
# transformer_conv,
38+
include("layers/basic.jl")
39+
export dot_decoder
6040

61-
# # layers/temporalconv
62-
# a3tgcn_conv,
41+
include("layers/conv.jl")
42+
export agnn_conv,
43+
cg_conv,
44+
cheb_conv,
45+
d_conv,
46+
edge_conv,
47+
egnn_conv,
48+
gat_conv,
49+
gatv2_conv,
50+
gated_graph_conv,
51+
gcn_conv,
52+
gin_conv,
53+
gmm_conv,
54+
graph_conv,
55+
megnet_conv,
56+
nn_conv,
57+
res_gated_graph_conv,
58+
sage_conv,
59+
sg_conv,
60+
tag_conv,
61+
transformer_conv
6362

64-
# # layers/pool
65-
# global_pool,
66-
# global_attention_pool,
67-
# set2set_pool,
68-
# topk_pool,
69-
# topk_index,
63+
include("layers/temporalconv.jl")
64+
export a3tgcn_conv
7065

66+
include("layers/pool.jl")
67+
export global_pool,
68+
global_attention_pool,
69+
set2set_pool,
70+
topk_pool,
71+
topk_index
7172

72-
include("utils.jl")
73-
include("layers/basic.jl")
74-
include("layers/conv.jl")
7573
# include("layers/heteroconv.jl") # no functional part at the moment
76-
include("layers/temporalconv.jl")
77-
include("layers/pool.jl")
78-
include("msgpass.jl")
7974

8075
end #module
8176

0 commit comments

Comments
 (0)