
Commit 0f8e13c

Add possibility to pass weights to GCNConv (#447)
* Modify GCNConv
* Add tests
* Update src/layers/conv.jl
1 parent df56b7e commit 0f8e13c

File tree

src/layers/conv.jl
test/layers/conv.jl

2 files changed: +28 -6 lines changed


src/layers/conv.jl

Lines changed: 17 additions & 6 deletions
@@ -32,14 +32,15 @@ and optionally an edge weight vector.
 
 # Forward
 
-    (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d)) -> AbstractMatrix
+    (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d); conv_weight::Union{Nothing,AbstractMatrix} = nothing) -> AbstractMatrix
 
-Takes as input a graph `g`,ca node feature matrix `x` of size `[in, num_nodes]`,
+Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`,
 and optionally an edge weight vector. Returns a node feature matrix of size
 `[out, num_nodes]`.
 
 The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument.
 By default, it computes ``\frac{1}{\sqrt{d}}`` i.e the inverse square root of the degree (`d`) of each node in the graph.
+If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix instead of the weights stored in the model.
 
 # Examples
 
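Usage note: the docstring change above introduces a `conv_weight` keyword that overrides the layer's stored weights for a single forward call. A minimal sketch of the new call pattern, assuming the package's `rand_graph` helper; the graph size and channel dimensions are illustrative:

```julia
using GraphNeuralNetworks

l = GCNConv(4 => 2)                  # in = 4 features, out = 2 features
g = rand_graph(10, 30)               # small random graph: 10 nodes, 30 edges
x = rand(Float32, 4, 10)             # node features of size [in, num_nodes]

y_stored = l(g, x)                   # uses the weights stored in the layer

w = rand(Float32, 2, 4)              # must match size(l.weight), i.e. [out, in]
y_custom = l(g, x; conv_weight = w)  # uses `w` instead of `l.weight`
```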
@@ -102,11 +103,21 @@ check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing
 function (l::GCNConv)(g::AbstractGNNGraph,
                       x,
                       edge_weight::EW = nothing,
-                      norm_fn::Function = d -> 1 ./ sqrt.(d)
+                      norm_fn::Function = d -> 1 ./ sqrt.(d);
+                      conv_weight::Union{Nothing,AbstractMatrix} = nothing
                       ) where {EW <: Union{Nothing, AbstractVector}}
 
     check_gcnconv_input(g, edge_weight)
 
+    if conv_weight === nothing
+        weight = l.weight
+    else
+        weight = conv_weight
+        if size(weight) != size(l.weight)
+            throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))"))
+        end
+    end
+
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
@@ -116,11 +127,11 @@ function (l::GCNConv)(g::AbstractGNNGraph,
             @assert length(edge_weight) == g.num_edges
         end
     end
-    Dout, Din = size(l.weight)
+    Dout, Din = size(weight)
     if Dout < Din && !(g isa GNNHeteroGraph)
         # multiply before convolution if it is more convenient, otherwise multiply after
         # (this works only for homogenous graph)
-        x = l.weight * x
+        x = weight * x
     end
 
     xj, xi = expand_srcdst(g, x) # expand only after potential multiplication
@@ -150,7 +161,7 @@ function (l::GCNConv)(g::AbstractGNNGraph,
     end
     x = x .* cin'
     if Dout >= Din || g isa GNNHeteroGraph
-        x = l.weight * x
+        x = weight * x
     end
     return l.σ.(x .+ l.bias)
 end
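The size check added above means a mismatched override fails fast with a descriptive error rather than producing silently wrong output. A sketch of the error path, with illustrative dimensions:

```julia
using GraphNeuralNetworks

l = GCNConv(4 => 2)
g = rand_graph(10, 30)
x = rand(Float32, 4, 10)

w_bad = rand(Float32, 4, 2)  # transposed shape: [in, out] instead of [out, in]
try
    l(g, x; conv_weight = w_bad)
catch err
    # ArgumentError: The weight matrix has the wrong size. Expected (2, 4) but got (4, 2)
    println(sprint(showerror, err))
end
```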

test/layers/conv.jl

Lines changed: 11 additions & 0 deletions
@@ -64,6 +64,17 @@ test_graphs = [g1, g_single_vertex]
         @test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{T} # redundant test but more explicit
         test_layer(l, g, rtol = RTOL_HIGH, outsize = (1, g.num_nodes), test_gpu = false)
     end
+
+    @testset "conv_weight" begin
+        l = GraphNeuralNetworks.GCNConv(in_channel => out_channel)
+        w = zeros(T, out_channel, in_channel)
+        g1 = GNNGraph(adj1, ndata = ones(T, in_channel, N))
+        @test l(g1, g1.ndata.x, conv_weight = w) == zeros(T, out_channel, N)
+        a = rand(T, in_channel, N)
+        g2 = GNNGraph(adj1, ndata = a)
+        @test l(g2, g2.ndata.x, conv_weight = w) == w * a
+
+    end
 end
 
 @testset "ChebConv" begin
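These tests lean on GCNConv's defaults (identity activation, zero-initialized bias): with a zero `conv_weight`, the output must be all zeros regardless of the input features, and `w * a` is likewise a zero matrix. A standalone sketch of the same check, with illustrative graph and channel sizes:

```julia
using GraphNeuralNetworks

in_channel, out_channel, n = 3, 5, 10
l = GCNConv(in_channel => out_channel)
g = rand_graph(n, 20; ndata = ones(Float32, in_channel, n))

w = zeros(Float32, out_channel, in_channel)
@assert l(g, g.ndata.x; conv_weight = w) == zeros(Float32, out_channel, n)
```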
