Skip to content

Commit 65de740

Browse files
rename FeaturedGraph to GNNGraph
1 parent 275b39f commit 65de740

22 files changed

+457
-458
lines changed

docs/make.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ makedocs(;
77
modules=[GraphNeuralNetworks],
88
sitename = "GraphNeuralNetworks.jl",
99
pages = ["Home" => "index.md",
10-
"Graphs" => "graphs.md",
11-
"Message passing" => "messagepassing.md",
12-
"Building models" => "models.md",
10+
"GNNGraph" => "gnngraph.md",
11+
"Message Passing" => "messagepassing.md",
12+
"Building Models" => "models.md",
1313
"API Reference" =>
1414
[
15-
"Graphs" => "api/graphs.md",
15+
"GNNGraph" => "api/gnngraph.md",
1616
"Convolutional Layers" => "api/conv.md",
1717
"Pooling Layers" => "api/pool.md",
1818
],

docs/src/api/graphs.md renamed to docs/src/api/gnngraph.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ Pages = ["api/graphs.md"]
77

88
```@autodocs
99
Modules = [GraphNeuralNetworks]
10-
Pages = ["featuredgraph.jl"]
10+
Pages = ["gnngraph.jl"]
1111
Private = false
1212
```
File renamed without changes.

examples/cora.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ function GNN(; nin, nhidden, nout)
2323
Dense(nhidden, nout))
2424
end
2525

26-
function (net::GNN)(fg, x)
27-
x = net.conv1(fg, x)
26+
function (net::GNN)(g, x)
27+
x = net.conv1(g, x)
2828
x = dropout(x, 0.5)
29-
x = net.conv2(fg, x)
29+
x = net.conv2(g, x)
3030
x = net.dense(x)
3131
return x
3232
end
3333

34-
function eval_loss_accuracy(X, y, ids, model, fg)
35-
= model(fg, X)
34+
function eval_loss_accuracy(X, y, ids, model, g)
35+
= model(g, X)
3636
l = logitcrossentropy(ŷ[:,ids], y[:,ids])
3737
acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
3838
return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
@@ -64,7 +64,7 @@ function train(; kws...)
6464
end
6565

6666
data = Cora.dataset()
67-
fg = FeaturedGraph(data.adjacency_list) |> device
67+
g = GNNGraph(data.adjacency_list) |> device
6868
X = data.node_features |> device
6969
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
7070
train_ids = data.train_indices |> device
@@ -78,19 +78,19 @@ function train(; kws...)
7878
ps = Flux.params(model)
7979
opt = ADAM(args.η)
8080

81-
@info "NUM NODES: $(fg.num_nodes) NUM EDGES: $(fg.num_edges)"
81+
@info "NUM NODES: $(g.num_nodes) NUM EDGES: $(g.num_edges)"
8282

8383
function report(epoch)
84-
train = eval_loss_accuracy(X, y, train_ids, model, fg)
85-
test = eval_loss_accuracy(X, y, test_ids, model, fg)
84+
train = eval_loss_accuracy(X, y, train_ids, model, g)
85+
test = eval_loss_accuracy(X, y, test_ids, model, g)
8686
println("Epoch: $epoch Train: $(train) Test: $(test)")
8787
end
8888

8989
## TRAINING
9090
report(0)
9191
for epoch in 1:args.epochs
9292
gs = Flux.gradient(ps) do
93-
= model(fg, X)
93+
= model(g, X)
9494
logitcrossentropy(ŷ[:,train_ids], ytrain)
9595
end
9696

perf/perf.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@ BenchmarkTools.ratio(x, ::Missing) = 0.0
88
BenchmarkTools.ratio(::Missing, ::Missing) = missing
99

1010
function run_single_benchmark(N, c, D, CONV; gtype=:lg)
11-
g = erdos_renyi(N, c / (N-1), seed=17)
11+
data = erdos_renyi(N, c / (N-1), seed=17)
1212
X = randn(Float32, D, N)
1313

14-
fg = FeaturedGraph(g; nf=X, graph_type=gtype)
15-
fg_gpu = fg |> gpu
14+
g = GNNGraph(data; nf=X, graph_type=gtype)
15+
g_gpu = g |> gpu
1616

1717
m = CONV(D => D)
1818
m_gpu = m |> gpu
1919

2020
res = Dict()
21-
res["CPU"] = @benchmark $m($fg)
21+
res["CPU"] = @benchmark $m($g)
2222

2323
try [GCNConv, GraphConv, GATConv]
24-
res["GPU"] = @benchmark CUDA.@sync($m_gpu($fg_gpu)) teardown=(GC.gc(); CUDA.reclaim())
24+
res["GPU"] = @benchmark CUDA.@sync($m_gpu($g_gpu)) teardown=(GC.gc(); CUDA.reclaim())
2525
catch
2626
res["GPU"] = missing
2727
end

src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv
1616
adjacency_matrix, degree
1717

1818
export
19-
# featured_graph
20-
FeaturedGraph,
19+
# gnngraph
20+
GNNGraph,
2121
edge_index,
2222
node_feature, edge_feature, global_feature,
2323
adjacency_list, normalized_laplacian, scaled_laplacian,
@@ -56,7 +56,7 @@ export
5656
bypass_graph
5757

5858

59-
include("featuredgraph.jl")
59+
include("gnngraph.jl")
6060
include("graph_conversions.jl")
6161
include("utils.jl")
6262

0 commit comments

Comments
 (0)