Skip to content

Commit 1dc219c

Browse files
Merge pull request #20 from CarloLucibello/cl/data
support multiple feature arrays in GNNGraph
2 parents cba6565 + 0d24bfd commit 1dc219c

File tree

15 files changed

+290
-289
lines changed

15 files changed

+290
-289
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
12+
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
1213
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -24,6 +25,7 @@ ChainRulesCore = "1"
2425
DataStructures = "0.18"
2526
Flux = "0.12"
2627
KrylovKit = "0.5"
28+
LearnBase = "0.5"
2729
LightGraphs = "1.3"
2830
MacroTools = "0.5"
2931
NNlib = "0.7"

docs/src/gnngraph.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
# Graphs
22

3+
TODO
4+
5+
```@docs
6+
GNNGraph
7+
```
8+

docs/src/messagepassing.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
# Message Passing
22

3+
TODO
4+
5+
```@docs
6+
propagate
7+
```

perf/perf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function run_single_benchmark(N, c, D, CONV; gtype=:lg)
1111
data = erdos_renyi(N, c / (N-1), seed=17)
1212
X = randn(Float32, D, N)
1313

14-
g = GNNGraph(data; nf=X, graph_type=gtype)
14+
g = GNNGraph(data; ndata=X, graph_type=gtype)
1515
g_gpu = g |> gpu
1616

1717
m = CONV(D => D)

src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
module GraphNeuralNetworks
22

3-
using Core: apply_type
4-
using NNlib: similar
5-
using LinearAlgebra: similar, fill!
63
using Statistics: mean
74
using LinearAlgebra
85
using SparseArrays
@@ -12,17 +9,17 @@ using CUDA
129
using Flux
1310
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
1411
using MacroTools: @forward
12+
using LearnBase: getobs
1513
using NNlib, NNlibCUDA
1614
using ChainRulesCore
1715
import LightGraphs
18-
using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv,
19-
adjacency_matrix, degree
16+
using LightGraphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
2017

2118
export
2219
# gnngraph
2320
GNNGraph,
2421
edge_index,
25-
node_feature, edge_feature, global_feature,
22+
node_features, edge_features, global_features,
2623
adjacency_list, normalized_laplacian, scaled_laplacian,
2724
add_self_loops, remove_self_loops,
2825
subgraph,
@@ -52,7 +49,6 @@ export
5249
topk_index
5350

5451

55-
5652

5753
include("gnngraph.jl")
5854
include("graph_conversions.jl")

0 commit comments

Comments
 (0)