From e90f6416fef61b96052b0598b37e5ec3afeb3195 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 4 Aug 2024 17:45:30 +0530 Subject: [PATCH 01/23] megnet WIP --- GNNLux/src/layers/conv.jl | 43 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 83c3efddc..c497cb8f9 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -628,3 +628,46 @@ function Base.show(io::IO, l::GINConv) print(io, ", $(l.ϵ)") print(io, ")") end + +@concrete struct MEGNetConv{TE, TV, A} <: GNNLayer + ϕe::TE + ϕv::TV + aggr::A + num_features::NamedTuple +end + +MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) + +function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) + nin, nout = ch + ϕe = Chain(Dense(3nin, nout, relu), + Dense(nout, nout)) + + ϕv = Chain(Dense(nin + nout, nout, relu), + Dense(nout, nout)) + + num_features = (in = in_size, edge = edge_feat_size, out = out_size, + hidden = hidden_size) + + return MEGNetConv(ϕe, ϕv; aggr, num_features) +end + + +LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,) + +(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function (l::MegNetConv)(g, x, e, ps, st) + ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) + ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) + m = (; ϕe, ϕv, l.residual, l.num_features) + return GNNlib.megnet_conv(m, g, x, e), st +end + +function Base.show(io::IO, l::MegNetConv) + ne = l.num_features.edge + nin = l.num_features.in + nout = l.num_features.out + print(io, "MegNetConv(($nin, $ne) => $nout") + print(io, ")") +end \ No newline at end of file From 6b1af1b608737ff5a5ce944a29045898d5d41bbb Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 4 Aug 2024 17:49:44 +0530 Subject: [PATCH 02/23] fix --- GNNLux/src/layers/conv.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index c497cb8f9..b8266b3b2 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -646,8 +646,7 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout)) - num_features = (in = in_size, edge = edge_feat_size, out = out_size, - hidden = hidden_size) + num_features = (in = nin, out = nout) return MEGNetConv(ϕe, ϕv; aggr, num_features) end From 9b76cf54a051e4f37128d238ef1d3fa6cccf5b63 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 4 Aug 2024 17:54:46 +0530 Subject: [PATCH 03/23] fix --- GNNLux/src/layers/conv.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index b8266b3b2..279d2beb3 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -629,11 +629,12 @@ function Base.show(io::IO, l::GINConv) print(io, ")") end -@concrete struct MEGNetConv{TE, TV, A} <: GNNLayer +@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)} + in_dims::Int + out_dims::Int ϕe::TE ϕv::TV aggr::A - num_features::NamedTuple end MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) @@ -646,15 +647,13 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout)) - num_features = (in = nin, out = nout) - - return MEGNetConv(ϕe, ϕv; aggr, num_features) + return MEGNetConv(nin, nout, ϕe, ϕv; aggr) end LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,) -(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) 
+(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) # check function (l::MegNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) From 4904f911e2bf45437067aa390238062a9633d74e Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 4 Aug 2024 17:56:45 +0530 Subject: [PATCH 04/23] fix output --- GNNLux/src/layers/conv.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 279d2beb3..7ea85e0a5 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -663,9 +663,8 @@ function (l::MegNetConv)(g, x, e, ps, st) end function Base.show(io::IO, l::MegNetConv) - ne = l.num_features.edge - nin = l.num_features.in - nout = l.num_features.out - print(io, "MegNetConv(($nin, $ne) => $nout") + nin = l.in_dims + nout = l.out_dims + print(io, "MegNetConv(", l.in_dims, " => ", l.out_dims) print(io, ")") end \ No newline at end of file From d5cfb7bc4014598a5fa9b47d3e7db60cdca3e798 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 04:24:29 +0530 Subject: [PATCH 05/23] wip --- GNNLux/src/GNNLux.jl | 2 +- GNNLux/src/layers/conv.jl | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index d8970095c..e932451be 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -30,7 +30,7 @@ export AGNNConv, GINConv, # GMMConv, GraphConv, - # MEGNetConv, + #MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 7ea85e0a5..6d7e6f70e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -650,11 +650,8 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) return MEGNetConv(nin, nout, ϕe, ϕv; aggr) end - LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,) -(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) # check - function (l::MegNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) @@ -665,6 +662,6 @@ end function Base.show(io::IO, l::MegNetConv) nin = l.in_dims nout = l.out_dims - print(io, "MegNetConv(", l.in_dims, " => ", l.out_dims) + print(io, "MegNetConv(", nin, " => ", nout) print(io, ")") end \ No newline at end of file From 3296f2eccc36011f8347036a85c327e645169bee Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:05:09 +0530 Subject: [PATCH 06/23] temporary changes to run tests --- GNNGraphs/test/runtests.jl | 4 ++-- GNNLux/src/GNNLux.jl | 2 +- GNNLux/src/layers/conv.jl | 13 ++++++++----- GNNLux/test/layers/conv_tests.jl | 20 ++++++++++++++++++-- GNNlib/test/runtests.jl | 2 +- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/GNNGraphs/test/runtests.jl b/GNNGraphs/test/runtests.jl index 0c648d2a4..da90a56a3 100644 --- a/GNNGraphs/test/runtests.jl +++ b/GNNGraphs/test/runtests.jl @@ -23,7 +23,7 @@ const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}} ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") - +""" tests = [ "chainrules", "datastore", @@ -39,7 +39,7 @@ tests = [ "mldatasets", "ext/SimpleWeightedGraphs" ] - +""" !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") for graph_type in (:coo, :dense, :sparse) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index e932451be..3aa3251d0 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -30,7 +30,7 @@ export AGNNConv, GINConv, # GMMConv, 
GraphConv, - #MEGNetConv, + MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 6d7e6f70e..a42af3db8 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -650,18 +650,21 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) return MEGNetConv(nin, nout, ϕe, ϕv; aggr) end -LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,) - -function (l::MegNetConv)(g, x, e, ps, st) +function (l::MEGNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) m = (; ϕe, ϕv, l.residual, l.num_features) return GNNlib.megnet_conv(m, g, x, e), st end -function Base.show(io::IO, l::MegNetConv) + +LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,) + +(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function Base.show(io::IO, l::MEGNetConv) nin = l.in_dims nout = l.out_dims - print(io, "MegNetConv(", nin, " => ", nout) + print(io, "MEGNetConv(", nin, " => ", nout) print(io, ")") end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9f010f39e..f62a12eb2 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -4,7 +4,7 @@ in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) - + """ @testset "GCNConv" begin l = GCNConv(in_dims => out_dims, relu) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) @@ -53,7 +53,22 @@ @test size(hnew) == (hout, g.num_nodes) @test size(xnew) == (in_dims, g.num_nodes) end - + """ + @testset "MEGNetConv" begin + in_dims = 6 + out_dims = 8 + + l = MEGNetConv(in_dims => out_dims) + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + + (x_new, e_new), st_new = l(g, x, ps, st) + + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) + end + """ @testset "GATConv" begin x = randn(rng, Float32, 6, 10) @@ -93,4 +108,5 @@ l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end + """ end diff --git a/GNNlib/test/runtests.jl b/GNNlib/test/runtests.jl index e4c4512b4..32276f937 100644 --- a/GNNlib/test/runtests.jl +++ b/GNNlib/test/runtests.jl @@ -3,4 +3,4 @@ using Test using ReTestItems using Random, Statistics -runtests(GNNlib) +#runtests(GNNlib) From 73a3d0ec12e9004c2062f612783d6af442ed70ab Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:11:49 +0530 Subject: [PATCH 07/23] testing --- GNNGraphs/test/runtests.jl | 4 +- GNNLux/test/layers/basic_tests.jl | 3 ++ GNNLux/test/layers/conv_tests.jl | 90 ------------------------------- GNNlib/test/runtests.jl | 2 +- 4 files changed, 6 insertions(+), 93 deletions(-) diff --git a/GNNGraphs/test/runtests.jl b/GNNGraphs/test/runtests.jl index da90a56a3..0c648d2a4 100644 --- a/GNNGraphs/test/runtests.jl +++ b/GNNGraphs/test/runtests.jl @@ -23,7 +23,7 @@ const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}} ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") -""" + tests = [ "chainrules", "datastore", @@ -39,7 +39,7 @@ tests = [ "mldatasets", "ext/SimpleWeightedGraphs" ] -""" + !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") for graph_type in (:coo, :dense, :sparse) diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 9f59f3b10..405359af8 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ 
-1,4 +1,6 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin + @test 1==1 + """ rng = StableRNG(17) g = rand_graph(10, 40, seed=17) x = randn(rng, Float32, 3, 10) @@ -16,4 +18,5 @@ c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) end + """ end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index f62a12eb2..cd5f69c40 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -4,56 +4,7 @@ in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) - """ - @testset "GCNConv" begin - l = GCNConv(in_dims => out_dims, relu) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "ChebConv" begin - l = ChebConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "GraphConv" begin - l = GraphConv(in_dims => out_dims, relu) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "AGNNConv" begin - l = AGNNConv(init_beta=1.0f0) - test_lux_layer(rng, l, g, x, sizey=(in_dims, 10)) - end - - @testset "EdgeConv" begin - nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) - l = EdgeConv(nn, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) - end - - @testset "CGConv" begin - l = CGConv(in_dims => in_dims, residual = true) - test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true) - end - - @testset "DConv" begin - l = DConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(5,)) - end - @testset "EGNNConv" begin - hin = 6 - hout = 7 - hidden = 8 - l = EGNNConv(hin => hout, hidden) - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - h = randn(rng, Float32, hin, g.num_nodes) - (hnew, xnew), stnew = l(g, h, x, ps, st) - @test size(hnew) == (hout, g.num_nodes) - @test size(xnew) == (in_dims, g.num_nodes) - end - """ @testset "MEGNetConv" begin in_dims = 6 out_dims = 8 @@ -68,45 +19,4 @@ @test size(x_new) == (out_dims, g.num_nodes) @test size(e_new) == (out_dims, g.num_edges) end - """ - @testset "GATConv" begin - x = randn(rng, Float32, 6, 10) - - l = GATConv(6 => 8, heads=2) - test_lux_layer(rng, l, g, x, outputsize=(16,)) - - l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5) - test_lux_layer(rng, l, g, x, outputsize=(8,)) - - #TODO test edge - end - - @testset "GATv2Conv" begin - x = randn(rng, Float32, 6, 10) - - l = GATv2Conv(6 => 8, heads=2) - test_lux_layer(rng, l, g, x, outputsize=(16,)) - - l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5) - test_lux_layer(rng, l, g, x, outputsize=(8,)) - - #TODO test edge - end - - @testset "SGConv" begin - l = SGConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "GatedGraphConv" begin - l = GatedGraphConv(in_dims, 3) - test_lux_layer(rng, l, g, x, outputsize=(in_dims,)) - end - - @testset "GINConv" begin - nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) - l = GINConv(nn, 0.5) - test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) - end - """ end diff --git a/GNNlib/test/runtests.jl b/GNNlib/test/runtests.jl index 32276f937..e4c4512b4 100644 --- a/GNNlib/test/runtests.jl +++ b/GNNlib/test/runtests.jl @@ -3,4 +3,4 @@ using Test using ReTestItems using Random, Statistics -#runtests(GNNlib) +runtests(GNNlib) From 96a9233697af5798eeac3f07fabc8595c9caa7fb Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:13:27 +0530 Subject: [PATCH 
08/23] test --- GNNLux/test/layers/conv_tests.jl | 90 ++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index cd5f69c40..ad1bc7a2b 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -4,6 +4,96 @@ in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) + """ + @testset "GCNConv" begin + l = GCNConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "ChebConv" begin + l = ChebConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "GraphConv" begin + l = GraphConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "AGNNConv" begin + l = AGNNConv(init_beta=1.0f0) + test_lux_layer(rng, l, g, x, sizey=(in_dims, 10)) + end + + @testset "EdgeConv" begin + nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) + l = EdgeConv(nn, aggr = +) + test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) + end + + @testset "CGConv" begin + l = CGConv(in_dims => in_dims, residual = true) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true) + end + + @testset "DConv" begin + l = DConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(5,)) + end + + @testset "EGNNConv" begin + hin = 6 + hout = 7 + hidden = 8 + l = EGNNConv(hin => hout, hidden) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + h = randn(rng, Float32, hin, g.num_nodes) + (hnew, xnew), stnew = l(g, h, x, ps, st) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_dims, g.num_nodes) + end + + @testset "GATConv" begin + x = randn(rng, Float32, 6, 10) + + l = GATConv(6 => 8, heads=2) + test_lux_layer(rng, l, g, x, outputsize=(16,)) + + l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5) + test_lux_layer(rng, l, g, x, outputsize=(8,)) + + #TODO test edge + end + + @testset "GATv2Conv" begin + x = randn(rng, Float32, 6, 10) + + l = GATv2Conv(6 => 8, heads=2) + test_lux_layer(rng, l, g, x, outputsize=(16,)) + + l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5) + test_lux_layer(rng, l, g, x, outputsize=(8,)) + + #TODO test edge + end + + @testset "SGConv" begin + l = SGConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "GatedGraphConv" begin + l = GatedGraphConv(in_dims, 3) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,)) + end + + @testset "GINConv" begin + nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + l = GINConv(nn, 0.5) + test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) + end + """ @testset "MEGNetConv" begin in_dims = 6 From a1fc342c538f403b7ad3969dba1b68d3fd6da9a1 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:24:15 +0530 Subject: [PATCH 09/23] test --- GNNLux/src/layers/conv.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index a42af3db8..e66cdfcbd 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -637,7 +637,8 @@ end aggr::A end -MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) +# 'mean' not defined +#MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) nin, nout = ch From 97cc7697e37172c7ff18a533403b4232c20385fd Mon Sep 17 00:00:00 2001 From: rbSparky Date: 
Thu, 8 Aug 2024 05:32:12 +0530 Subject: [PATCH 10/23] mean --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index e66cdfcbd..b05f46020 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -640,7 +640,7 @@ end # 'mean' not defined #MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) -function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) +function MEGNetConv(ch::Pair{Int, Int}; aggr) nin, nout = ch ϕe = Chain(Dense(3nin, nout, relu), Dense(nout, nout)) From bd2c3351551533e4fa1ed77724660af074952bd1 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:38:00 +0530 Subject: [PATCH 11/23] mean --- GNNLux/src/GNNLux.jl | 1 + GNNLux/src/layers/conv.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 3aa3251d0..b24ed4118 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -1,6 +1,7 @@ module GNNLux using ConcreteStructs: @concrete using NNlib: NNlib, sigmoid, relu, swish +using Statistics: mean using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, initialparameters, initialstates, parameterlength, statelength using Lux: Lux, Chain, Dense, GRUCell, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index b05f46020..e66cdfcbd 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -640,7 +640,7 @@ end # 'mean' not defined #MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) -function MEGNetConv(ch::Pair{Int, Int}; aggr) +function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) nin, nout = ch ϕe = Chain(Dense(3nin, nout, relu), Dense(nout, nout)) From 5c0c5a8eef09b0ce6e15ef59c8d21296738cf66e Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:45:03 +0530 Subject: [PATCH 12/23] fix --- GNNLux/src/layers/conv.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index e66cdfcbd..d1f839c9d 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -637,8 +637,9 @@ end aggr::A end -# 'mean' not defined -#MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) +function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A} + return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr) +end function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) nin, nout = ch @@ -648,7 +649,7 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout)) - return MEGNetConv(nin, nout, ϕe, ϕv; aggr) + return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr) end function (l::MEGNetConv)(g, x, e, ps, st) From 79f3115f66b4a1cf93a51032dc50444ebb2a1ce1 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:55:19 +0530 Subject: [PATCH 13/23] fix --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index d1f839c9d..a4fe957e9 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -655,7 +655,7 @@ end function (l::MEGNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) - m = (; ϕe, ϕv, l.residual, l.num_features) + m = (; ϕe, ϕv) return GNNlib.megnet_conv(m, g, x, e), st end From 
578968b6479ddc80d8174ddfa67c41f8a7e0a275 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 05:56:51 +0530 Subject: [PATCH 14/23] fix --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index a4fe957e9..9a7259689 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -655,7 +655,7 @@ end function (l::MEGNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) - m = (; ϕe, ϕv) + m = (; ϕe, ϕv, l.aggr) return GNNlib.megnet_conv(m, g, x, e), st end From 59ee768d1c02920de93490349b8981c9a62536ab Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 06:07:41 +0530 Subject: [PATCH 15/23] added edge check --- GNNLux/src/layers/conv.jl | 2 +- GNNLux/test/layers/conv_tests.jl | 3 --- GNNlib/src/layers/conv.jl | 13 +++++++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 9a7259689..30564ae48 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -655,7 +655,7 @@ end function (l::MEGNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) - m = (; ϕe, ϕv, l.aggr) + m = (; ϕe, ϕv, aggr=l.aggr) return GNNlib.megnet_conv(m, g, x, e), st end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ad1bc7a2b..ca1ed68d6 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -96,9 +96,6 @@ """ @testset "MEGNetConv" begin - in_dims = 6 - out_dims = 8 - l = MEGNetConv(in_dims => out_dims) ps = LuxCore.initialparameters(rng, l) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 50b5b34aa..4ad3a8768 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -355,18 +355,23 @@ end ####################### MegNetConv ###################################### -function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) +function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing) check_num_nodes(g, x) - ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e + if isnothing(e) + num_edges = g.num_edges + e = zeros(eltype(x), 0, num_edges) # Empty matrix with correct number of columns + end + + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e l.ϕe(vcat(xi, xj, e)) end - xᵉ = aggregate_neighbors(g, l.aggr, ē) + xᵉ = aggregate_neighbors(g, l.aggr, ē) x̄ = l.ϕv(vcat(x, xᵉ)) - return x̄, ē + return x̄, ē end ####################### GMMConv ###################################### From df4b2d25b2b602007719d94fcfa85ebd9a545789 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 06:16:07 +0530 Subject: [PATCH 16/23] test --- GNNLux/test/layers/conv_tests.jl | 4 ++-- GNNlib/src/layers/conv.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ca1ed68d6..2a867209d 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -103,7 +103,7 @@ (x_new, e_new), st_new = l(g, x, ps, st) - @test size(x_new) == (out_dims, g.num_nodes) - @test size(e_new) == (out_dims, g.num_edges) + #@test size(x_new) == (out_dims, g.num_nodes) + #@test size(e_new) == (out_dims, g.num_edges) end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl 
index 4ad3a8768..ebd9ed94a 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -360,7 +360,7 @@ function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, if isnothing(e) num_edges = g.num_edges - e = zeros(eltype(x), 0, num_edges) # Empty matrix with correct number of columns + e = zeros(eltype(x), 0, num_edges) end ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e From b41b1e089164f693616e88cf16fee71e39bf78ac Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 8 Aug 2024 06:22:22 +0530 Subject: [PATCH 17/23] fix --- GNNLux/test/layers/conv_tests.jl | 4 ++-- GNNlib/src/layers/conv.jl | 7 +------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 2a867209d..ca1ed68d6 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -103,7 +103,7 @@ (x_new, e_new), st_new = l(g, x, ps, st) - #@test size(x_new) == (out_dims, g.num_nodes) - #@test size(e_new) == (out_dims, g.num_edges) + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index ebd9ed94a..2c13e62fc 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -357,12 +357,7 @@ end function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing) check_num_nodes(g, x) - - if isnothing(e) - num_edges = g.num_edges - e = zeros(eltype(x), 0, num_edges) - end - + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e l.ϕe(vcat(xi, xj, e)) end From 62b3405ec2c0edf8180958ba44bd5037ba06fe66 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:10:34 +0530 Subject: [PATCH 18/23] Update basic_tests.jl --- GNNLux/test/layers/basic_tests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 405359af8..9f59f3b10 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -1,6 +1,4 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin - @test 1==1 - """ rng = StableRNG(17) g = rand_graph(10, 40, seed=17) x = randn(rng, Float32, 3, 10) @@ -18,5 +16,4 @@ c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) end - """ end From ff2670c01b870d770bb741ab8af83c6bd0a831e6 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:44:36 +0530 Subject: [PATCH 19/23] Update conv_tests.jl: Fixing tests --- GNNLux/test/layers/conv_tests.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ca1ed68d6..f24730777 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -95,14 +95,16 @@ end """ + @testset "MEGNetConv" begin l = MEGNetConv(in_dims => out_dims) - + ps = LuxCore.initialparameters(rng, l) st = LuxCore.initialstates(rng, l) - - (x_new, e_new), st_new = l(g, x, ps, st) - + + e = randn(rng, Float32, in_dims, g.num_edges) + (x_new, e_new), st_new = l(g, x, e, ps, st) + @test size(x_new) == (out_dims, g.num_nodes) @test size(e_new) == (out_dims, g.num_edges) end From ce2c1b65ae5c96a1bd424e4a5d55484907a4770b Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:55:49 +0530 Subject: 
[PATCH 20/23] Update conv.jl: Back to old commit --- GNNlib/src/layers/conv.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 2c13e62fc..0b7dd2499 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -355,18 +355,18 @@ end ####################### MegNetConv ###################################### -function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing) +function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) check_num_nodes(g, x) - - ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e + + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e l.ϕe(vcat(xi, xj, e)) end - xᵉ = aggregate_neighbors(g, l.aggr, ē) + xᵉ = aggregate_neighbors(g, l.aggr, ē) x̄ = l.ϕv(vcat(x, xᵉ)) - return x̄, ē + return x̄, ē end ####################### GMMConv ###################################### @@ -721,4 +721,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) T1_out = T2_out end return h .+ l.bias -end \ No newline at end of file +end From 032e6169b35c4926aa99611ed2efa9dc090f56e1 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:56:37 +0530 Subject: [PATCH 21/23] Update conv_tests.jl: Fix tests --- GNNLux/test/layers/conv_tests.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index f24730777..bef265040 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -4,7 +4,7 @@ in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) - """ + @testset "GCNConv" begin l = GCNConv(in_dims => out_dims, relu) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) @@ -93,9 +93,7 @@ l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end - """ - @testset "MEGNetConv" begin l = MEGNetConv(in_dims => out_dims) From b86573f9e9fd045e3f0eb7331f9480a606e121a1 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:57:06 +0530 Subject: [PATCH 22/23] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index bef265040..877e6e90b 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -53,7 +53,7 @@ @test size(hnew) == (hout, g.num_nodes) @test size(xnew) == (in_dims, g.num_nodes) end - + @testset "GATConv" begin x = randn(rng, Float32, 6, 10) From 0957a2542eda7105faacbf339bc390480c4aa285 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:57:19 +0530 Subject: [PATCH 23/23] Update conv.jl
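Usage sketch (not part of any patch in this series; illustrative only). It mirrors the call pattern used in GNNLux/test/layers/conv_tests.jl and assumes the final state of the series: MEGNetConv is a GNNContainerLayer wrapping the ϕe/ϕv Chains, its forward pass delegates to GNNlib.megnet_conv, and the edge-feature matrix must have the same number of rows as the node features, since ϕe is applied to vcat(xi, xj, e). The dimensions, graph size, and rng below are arbitrary illustrative choices; rand_graph is assumed to come from GNNGraphs.

    using GNNGraphs, GNNLux, LuxCore, Random

    # Illustrative setup, mirroring the shared test setup (10 nodes, 40 edges).
    rng = Random.default_rng()
    g = rand_graph(10, 40)
    x = randn(rng, Float32, 3, g.num_nodes)   # node features, nin = 3
    e = randn(rng, Float32, 3, g.num_edges)   # edge features with nin rows as well

    # Construct the layer; the ϕe and ϕv Chains are built internally and aggr defaults to mean.
    l = MEGNetConv(3 => 5)

    # Lux layers are stateless: parameters and state are created outside the layer and passed in.
    ps = LuxCore.initialparameters(rng, l)
    st = LuxCore.initialstates(rng, l)

    # Forward pass returns updated node and edge features plus the new layer state.
    (x̄, ē), st = l(g, x, e, ps, st)
    @assert size(x̄) == (5, g.num_nodes)
    @assert size(ē) == (5, g.num_edges)

Edge features are passed explicitly here: the series keeps the convenience method l(g, x, ps, st) = l(g, x, nothing, ps, st), but after the GNNlib revert in patch 20, megnet_conv once again takes an AbstractMatrix of edge features, so the nothing path has no matching method in that final state.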