Skip to content

Commit 786b200

Browse files
committed
WIP
1 parent 83b6b7e commit 786b200

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,74 @@ function Base.show(io::IO, l::GINConv)
628628
print(io, ", $(l.ϵ)")
629629
print(io, ")")
630630
end
631+
632+
# Edge-conditioned convolution layer (NNConv): a container layer wrapping an
# inner neural network `nn` that maps edge features to per-edge weight matrices.
@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
    nn <: AbstractExplicitLayer   # edge network: edge features -> message transform
    aggr                          # neighborhood aggregation function (e.g. `+`, `mean`)
    in_dims::Int                  # input node-feature dimension
    out_dims::Int                 # output node-feature dimension
    use_bias::Bool                # whether a learnable bias vector is created
    add_self_loops::Bool          # add self-loops to the graph before convolving
    use_edge_weight::Bool         # weight messages by scalar edge weights
    init_weight                   # initializer for the dense weight matrix
    init_bias                     # initializer for the bias vector
    σ                             # activation applied to the output
end
644+
645+
"""
646+
function NNConv(ch::Pair{Int, Int}, σ = identity;
647+
init_weight = glorot_uniform,
648+
init_bias = zeros32,
649+
use_bias::Bool = true,
650+
add_self_loops::Bool = true,
651+
use_edge_weight::Bool = false,
652+
allow_fast_activation::Bool = true)
653+
"""
654+
# fix args order
655+
function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
656+
aggr = +,
657+
init_bias = zeros32,
658+
use_bias::Bool = true,
659+
init_weight = glorot_uniform,
660+
add_self_loops::Bool = true,
661+
use_edge_weight::Bool = false,
662+
allow_fast_activation::Bool = true)
663+
in_dims, out_dims = ch
664+
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
665+
return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
666+
end
667+
668+
# Forward pass for NNConv.
# BUG FIX: the original dispatched on `l::GCNConv`, but the body reads `l.nn`,
# which `GCNConv` does not have (and it would clash with GCNConv's own method).
function (l::NNConv)(g, x, edge_weight, ps, st)
    # Wrap the inner edge-network so that its updated state can be retrieved
    # after the call via `_getstate`.
    # NOTE(review): this forwards the layer's full `ps`/`st` to the inner `nn`;
    # if parameters/state are nested per sub-layer (e.g. `ps.nn`), the sub-trees
    # should be passed instead — confirm against `initialparameters`.
    nn = StatefulLuxLayer{true}(l.nn, ps, st)

    # Assemble the immutable parameter bundle expected by GNNlib.nn_conv.
    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps),
         l.add_self_loops, l.use_edge_weight, l.σ)
    y = GNNlib.nn_conv(m, g, x, edge_weight)
    stnew = _getstate(nn)
    return y, stnew
end
678+
679+
# Build the trainable parameters of an `NNConv`: a dense `out × in` weight
# matrix plus, when enabled, an `out`-dimensional bias vector.
# NOTE(review): the inner `nn`'s parameters are not initialized here — confirm
# whether they should be nested under a `:nn` key for `GNNContainerLayer`.
function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
    weight = l.init_weight(rng, l.out_dims, l.in_dims)
    return l.use_bias ? (; weight, bias = l.init_bias(rng, l.out_dims)) : (; weight)
end
688+
689+
# Number of trainable scalars: weight matrix entries, plus bias entries if used.
# NOTE(review): the inner `nn`'s parameters are not counted, consistent with
# `initialparameters` above — confirm once the container nesting is settled.
function LuxCore.parameterlength(l::NNConv)
    nweights = l.in_dims * l.out_dims
    return l.use_bias ? nweights + l.out_dims : nweights
end

# Trailing dimension of the layer's output features.
LuxCore.outputsize(d::NNConv) = (d.out_dims,)
691+
692+
693+
# Compact display of an `NNConv`, listing only non-default options.
# BUG FIX: the original dispatched on `l::GINConv` (conflicting with the
# existing GINConv `show` method) and printed `l.ϵ`, a field `NNConv`
# does not have, which would throw at display time.
function Base.show(io::IO, l::NNConv)
    print(io, "NNConv($(l.nn)")
    l.σ == identity || print(io, ", ", l.σ)
    l.use_bias || print(io, ", use_bias=false")
    l.add_self_loops || print(io, ", add_self_loops=false")
    l.use_edge_weight && print(io, ", use_edge_weight=true")
    print(io, ")")
end

0 commit comments

Comments
 (0)