Skip to content

Commit 0a23ffa

Browse files
authored
[GNNLux] Added SGConv (#475)
* added sgconv lux * fix * fix * fix * fix
1 parent c82efa0 commit 0a23ffa

File tree

4 files changed

+64
-3
lines changed

4 files changed

+64
-3
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ export AGNNConv,
2626
GCNConv,
2727
# GINConv,
2828
# GMMConv,
29-
GraphConv
29+
GraphConv,
3030
# MEGNetConv,
3131
# NNConv,
3232
# ResGatedGraphConv,
3333
# SAGEConv,
34-
# SGConv,
34+
SGConv
3535
# TAGConv,
3636
# TransformerConv
3737

GNNLux/src/layers/conv.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv)
515515
l.σ == identity || print(io, ", ", l.σ)
516516
print(io, ", negative_slope=", l.negative_slope)
517517
print(io, ")")
518+
end
519+
520+
@concrete struct SGConv <: GNNLayer
521+
in_dims::Int
522+
out_dims::Int
523+
k::Int
524+
use_bias::Bool
525+
add_self_loops::Bool
526+
use_edge_weight::Bool
527+
init_weight
528+
init_bias
529+
end
530+
531+
function SGConv(ch::Pair{Int, Int}, k = 1;
532+
init_weight = glorot_uniform,
533+
init_bias = zeros32,
534+
use_bias::Bool = true,
535+
add_self_loops::Bool = true,
536+
use_edge_weight::Bool = false)
537+
in_dims, out_dims = ch
538+
return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias)
539+
end
540+
541+
function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
542+
weight = l.init_weight(rng, l.out_dims, l.in_dims)
543+
if l.use_bias
544+
bias = l.init_bias(rng, l.out_dims)
545+
return (; weight, bias)
546+
else
547+
return (; weight)
548+
end
549+
end
550+
551+
LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
552+
LuxCore.statelength(d::SGConv) = 0
553+
LuxCore.outputsize(d::SGConv) = (d.out_dims,)
554+
555+
function Base.show(io::IO, l::SGConv)
556+
print(io, "SGConv(", l.in_dims, " => ", l.out_dims)
557+
l.k || print(io, ", ", l.k)
558+
l.use_bias || print(io, ", use_bias=false")
559+
l.add_self_loops || print(io, ", add_self_loops=false")
560+
!l.use_edge_weight || print(io, ", use_edge_weight=true")
561+
print(io, ")")
562+
end
563+
564+
(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
565+
l(g, x, edge_weight, ps, st; conv_weight)
566+
567+
function (l::SGConv)(g, x, edge_weight, ps, st;
568+
conv_weight=nothing, )
569+
570+
m = (; ps.weight, bias = _getbias(ps),
571+
l.add_self_loops, l.use_edge_weight, l.k)
572+
y = GNNlib.sg_conv(m, g, x, edge_weight)
573+
return y, st
518574
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,10 @@
7777

7878
#TODO test edge
7979
end
80+
81+
@testset "SGConv" begin
82+
l = SGConv(in_dims => out_dims, 2)
83+
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
84+
end
8085
end
8186

GNNlib/src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,4 +722,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
722722
T1_out = T2_out
723723
end
724724
return h .+ l.bias
725-
end
725+
end

0 commit comments

Comments
 (0)