
Commit d0d39d7: adjustment to GINConv
Parent: c15b750

File tree: 7 files changed, +39 / -22 lines


Project.toml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -21,6 +22,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
+Adapt = "3"
 CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -63,5 +63,6 @@ include("msgpass.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/pool.jl")
+include("deprecations.jl")
 
 end

src/deprecations.jl

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# Deprecated in v0.1
+
+@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
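
For reference, the `@deprecate` macro forwards the old keyword-style constructor to the new positional one, so existing user code keeps working but emits a warning. A minimal sketch of the call mapping (the `Dense` dimensions here are illustrative):

using Flux, GraphNeuralNetworks

nn = Dense(16, 16, relu)

# old form: still accepted after this commit, but triggers a deprecation warning
l = GINConv(nn; eps=0.001f0, aggr=+)

# new form introduced by this commit
l = GINConv(nn, 0.001f0; aggr=+)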

src/layers/conv.jl

Lines changed: 17 additions & 9 deletions
@@ -407,7 +407,7 @@ end
 
 
 @doc raw"""
-    GINConv(f; eps = 0f0)
+    GINConv(f, ϵ; aggr=+)
 
 Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)
 
@@ -420,30 +420,38 @@ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer o
 # Arguments
 
 - `f`: A (possibly learnable) function acting on node features.
-- `eps`: Weighting factor.
+- `ϵ`: Weighting factor.
 """
 struct GINConv{R<:Real} <: GNNLayer
     nn
-    eps::R
+    ϵ::R
+    aggr
 end
 
 @functor GINConv
-Flux.trainable(l::GINConv) = (nn=l.nn,)
+Flux.trainable(l::GINConv) = (l.nn,)
+
+GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
 
-function GINConv(nn; eps=0f0)
-    GINConv(nn, eps)
-end
 
 compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
-update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
+update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 
 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    X, _ = propagate(l, g, +, X)
+    X, _ = propagate(l, g, l.aggr, X)
     X
 end
 
 
+function Base.show(io::IO, l::GINConv)
+    print(io, "GINConv($(l.nn)")
+    print(io, ", $(l.ϵ)")
+    print(io, ")")
+end
+
+
+
 @doc raw"""
     NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
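
For context, a quick sketch of constructing and applying the layer under its new signature; the graph, feature sizes, and ϵ value below are made up for illustration:

using Flux, GraphNeuralNetworks
using Statistics: mean

# a small graph in COO form (source and target vectors), illustrative only
g = GNNGraph([1, 2, 3], [2, 3, 1])
X = rand(Float32, 4, g.num_nodes)    # 4 input features per node

l = GINConv(Dense(4, 8, relu), 0.01f0; aggr=mean)
Y = l(g, X)                          # 8 × num_nodes output node features
@assert size(Y) == (8, g.num_nodes)

With `aggr=mean` the neighbor messages are averaged instead of summed; the default `aggr=+` matches the sum aggregation used in the GIN paper.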

src/utils.jl

Lines changed: 3 additions & 0 deletions
@@ -67,3 +67,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
     end
     return data
 end
+
+
+ofeltype(x, y) = convert(float(eltype(x)), y)
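
The new helper converts a scalar to the floating-point element type of an array. It is what lets `update_node` above scale the node features by `1 + ϵ` without promoting `Float32` features to `Float64` when `ϵ` happens to be a `Float64`. A small illustration:

ofeltype(x, y) = convert(float(eltype(x)), y)

x = rand(Float32, 4, 3)
ϵ = 0.01                               # a Float64 scalar
eltype((1 + ϵ) * x)                    # Float64: unwanted promotion
eltype((1 + ofeltype(x, ϵ)) * x)       # Float32: element type preserved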

test/examples/node_classification_cora.jl

Lines changed: 10 additions & 10 deletions
@@ -17,11 +17,11 @@ end
 
 # arguments for the `train` function
 Base.@kwdef mutable struct Args
-    η = 1f-3             # learning rate
-    epochs = 20          # number of epochs
+    η = 5f-3             # learning rate
+    epochs = 10          # number of epochs
     seed = 17            # set seed > 0 for reproducibility
     usecuda = false      # if true use cuda (if available)
-    nhidden = 128        # dimension of hidden features
+    nhidden = 64         # dimension of hidden features
 end
 
 function train(Layer; verbose=false, kws...)
@@ -49,7 +49,7 @@ function train(Layer; verbose=false, kws...)
 
     ## DEFINE MODEL
     model = GNNChain(Layer(nin, nhidden),
-                     Dropout(0.5),
+                     # Dropout(0.5),
                      Layer(nhidden, nhidden),
                      Dense(nhidden, nout)) |> device
 
@@ -84,16 +84,16 @@ for Layer in [
            (nin, nout) -> GraphConv(nin => nout, relu, aggr=mean),
            (nin, nout) -> SAGEConv(nin => nout, relu),
            (nin, nout) -> GATConv(nin => nout, relu),
-           (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
-           (nin, nout) -> GINConv(Dense(nin, nout, relu)),
-           (nin, nout) -> ChebConv(nin => nout, 3),
-           (nin, nout) -> ResGatedGraphConv(nin => nout, relu),
+           (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr=mean),
+           (nin, nout) -> ChebConv(nin => nout, 2),
+           (nin, nout) -> ResGatedGraphConv(nin => nout, relu),
            # (nin, nout) -> NNConv(nin => nout), # needs edge features
            # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
            # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the traning set but does not generalize well
        ]
-        train_res, test_res = train(Layer, verbose=false)
-        # @show Layer(2,2) train_res, test_res
+
+        @show Layer(2,2)
+        train_res, test_res = train(Layer, verbose=true)
         @test train_res.acc > 95
         @test test_res.acc > 70
     end

test/layers/conv.jl

Lines changed: 3 additions & 3 deletions
@@ -111,10 +111,10 @@
 
     @testset "GINConv" begin
         nn = Dense(in_channel, out_channel)
-        eps = 0.001f0
-        l = GINConv(nn, eps=eps)
+
+        l = GINConv(nn, 0.01f0, aggr=mean)
         for g in test_graphs
-            test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps])
+            test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
        end
 
        @test !in(:eps, Flux.trainable(l))
