Commit 510e430

Merge branch 'add-GMMConv' of https://github.com/melioristic/GraphNeuralNetworks.jl into add-GMMConv

2 parents 4928894 + f815b2a

2 files changed, +52 -35 lines changed


src/layers/conv.jl

Lines changed: 49 additions & 32 deletions
@@ -1066,28 +1066,31 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
 end
 
 @doc raw"""
-    GMMConv(in => out, K, e_dim, σ=identity; [init, bias])
+    GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false)
 Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402)
 Performs the operation
 ```math
 \mathbf{x}_i' = \frac{1}{|N(i)|} \sum_{j\in N(i)}\frac{1}{K}\sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
 ```
 where
 ```math
-w^a_{k}(e^a) = \exp(\frac{-1}{2}(e^a - \mu^a_k)^T \Sigma^a_k^{-1}(e^a - \mu^a_k))
+w^a_{k}(e^a) = \exp\left(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k (e^a - \mu^a_k)\right)
 ```
-$\Theta_k$, $\mu^a_k$, $\Sigma^a_k^{-1}$ are learnable parameters.
+$\Theta_k$, $\mu^a_k$, $(\Sigma^{-1})^a_k$ are learnable parameters.
 
 The input to the layer is a node feature array 'X' of size `(num_features, num_nodes)` and an
 edge pseudo-coordinate array 'U' of size `(num_features, num_edges)`.
+
 # Arguments
-- `in`: Number of input features.
+
+- `in`: Number of input node features.
+- `ein`: Number of input edge features.
 - `out`: Number of output features.
-- `K` : Number of kernels. Default `1`.
-- `e_dim` : Dimensionality of pseudo coordinates. Has to correspond to the edge features dimension in the forward pass.
 - `σ`: Activation function. Default `identity`.
+- `K`: Number of kernels. Default `1`.
 - `bias`: Add learnable bias. Default `true`.
 - `init`: Weights' initializer. Default `glorot_uniform`.
+- `residual`: Residual connection. Default `false`.
 
 # Examples
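As an aside (not part of the diff): the kernel weights above are diagonal Gaussians over the edge pseudo-coordinates. A minimal sketch of the formula for a single edge, using throwaway arrays `e`, `mu`, and `sigma_inv` in place of the layer's learnable parameters:

```julia
# Toy evaluation of w_k(e) = exp(-1/2 * Σ_a sigma_inv[a,k] * (e[a] - mu[a,k])^2) for one edge.
ein, K = 2, 3
e         = randn(ein)       # pseudo-coordinates of a single edge
mu        = randn(ein, K)    # one mean vector per kernel
sigma_inv = randn(ein, K)    # one (diagonal) inverse-covariance vector per kernel

w = [exp(-0.5 * sum(sigma_inv[:, k] .* (e .- mu[:, k]).^2)) for k in 1:K]  # K kernel weights
```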
@@ -1096,12 +1099,12 @@ edge pseudo-coordinate array 'U' of size `(num_features, num_edges)`
 s = [1,1,2,3]
 t = [2,3,1,1]
 g = GNNGraph(s,t)
-in_feature, out_feature, K, e_dim = 4, 7, 8, 10
-x = randn(in_feature, g.num_nodes)
-e = randn(e_dim, g.num_edges)
+nin, ein, out, K = 4, 10, 7, 8
+x = randn(nin, g.num_nodes)
+e = randn(ein, g.num_edges)
 
 # create layer
-l = GMMConv(in_feature=>out_feature, K, e_dim)
+l = GMMConv((nin, ein) => out, K=K)
 
 # forward pass
 l(g, x, e)
@@ -1112,52 +1115,66 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
     sigma_inv::A
     bias::B
     σ::F
-    ch::Pair{NTuple{2,Int}, Int}
+    ch::Pair{NTuple{2,Int},Int}
     K::Int
     dense_x::Dense
+    residual::Bool
 end
 
 @functor GMMConv
 
-function GMMConv(ch::Pair{NTuple{2,Int}, Int},
+function GMMConv(ch::Pair{NTuple{2,Int},Int},
                  σ=identity;
                  K::Int=1,
+                 bias::Bool=true,
                  init=Flux.glorot_uniform,
-                 bias::Bool=true)
+                 residual=false)
+
     (nin, ein), out = ch
     mu = init(ein, K)
     sigma_inv = init(K, ein)
     b = bias ? Flux.create_bias(ones(out), true) : false
-    dense_x = Dense(in, out*K, bias=false)
-    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x)
+    dense_x = Dense(nin, out*K, bias=false)
+    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
 end
 
-function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
+function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+    (nin, ein), out = l.ch  # notational convenience
 
-    @assert (l.ein == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (ein=$(ein),num_edge=$(g.num_edges))"
+    @assert (ein == size(e, 1) && g.num_edges == size(e, 2)) "Pseudo-coordinate dim $(size(e)) does not match (ein=$(ein), num_edges=$(g.num_edges))"
 
     num_edges = g.num_edges
     d = degree(g, dir=:in)
-    u = reshape(u, (l.ein, 1, num_edges))
-    mu = reshape(l.mu, (l.ein, l.K, 1))
+    w = reshape(e, (ein, 1, num_edges))
+    mu = reshape(l.mu, (ein, l.K, 1))
 
-    e = -0.5*(u.-mu).^2
-    e = e .* ((reshape(l.sigma_inv, (l.ein, l.K, 1)).^2) )
-    e = exp.(sum(e, dims = 1 )) # (1, K, num_edge)
-
-    xj = reshape(l.dense_x(x), (l.ch[2],l.K,:)) # (out, K, num_nodes)
-    x = propagate(e_mul_xj, g, +, xj=xj, e=e)
-    x = dropdims(mean(x, dims=2), dims=2) # (out, num_nodes)
-    x = 1 / d .* x
+    w = -0.5 .* (w .- mu).^2
+    w = w .* reshape(l.sigma_inv, (ein, l.K, 1))
+    w = exp.(sum(w, dims=1))                   # (1, K, num_edges)
+
+    xj = reshape(l.dense_x(x), (out, l.K, :))  # (out, K, num_nodes)
+    m = propagate(e_mul_xj, g, +, xj=xj, e=w)
+    m = dropdims(mean(m, dims=2), dims=2)      # (out, num_nodes)
+    m = m ./ reshape(d, 1, :)                  # average over the in-neighborhood
+
+    m = l.σ(m .+ l.bias)
+
+    if l.residual
+        if size(x, 1) == size(m, 1)
+            m += x
+        else
+            @warn "Residual not applied: output features $(size(m, 1)) != input features $(size(x, 1))"
+        end
+    end
 
-    return l.σ(x .+ l.bias)
+    return m
 end
 
 function Base.show(io::IO, l::GMMConv)
-    in, out, K, ein = l.ch[1], l.ch[2], l.K, l.ein
-    print(io, "GMMConv(", in, " => ", out)
-    print(io, ", K=", K)
-    print(io, ", ein=", ein)
+    (nin, ein), out = l.ch
+    print(io, "GMMConv((", nin, ",", ein, ")=>", out)
+    print(io, ", K=", l.K)
+    print(io, ", σ=", l.σ)
     print(io, ")")
 
 end
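In the forward pass above, `propagate(e_mul_xj, g, +, xj=xj, e=w)` multiplies each neighbor's kernel-transformed features by the per-edge weights `w` and sums them into the target node; the result is then averaged over the `K` kernels, normalized by in-degree, and optionally combined with the input through the new residual branch. A hedged usage sketch of the layer as added here; the `residual=true` setting and the matching `nin == out` sizes are illustrative choices, not part of the commit's example:

```julia
using GraphNeuralNetworks

s, t = [1,1,2,3], [2,3,1,1]
g = GNNGraph(s, t)

nin, ein, out, K = 4, 10, 4, 8               # nin == out, so the residual branch can apply
x = randn(Float32, nin, g.num_nodes)
e = randn(Float32, ein, g.num_edges)

l = GMMConv((nin, ein) => out, tanh; K=K, residual=true)
y = l(g, x, e)                               # (out, num_nodes); x is added back because sizes match
```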

test/layers/conv.jl

Lines changed: 3 additions & 3 deletions
@@ -264,11 +264,11 @@
 end
 
 @testset "GMMConv" begin
-    e_dim = 10
+    ein_channel = 10
     K = 5
-    l = GMMConv(in_channel => out_channel, K=K, e_dim=e_dim)
+    l = GMMConv((in_channel, ein_channel) => out_channel, K=K)
     for g in test_graphs
-        g = GNNGraph(g, edata=rand(e_dim, g.num_edges))
+        g = GNNGraph(g, edata=rand(ein_channel, g.num_edges))
         test_layer(l, g, rtol=RTOL_HIGH, outsize = (out_channel, g.num_nodes))
     end
 end
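For context, a rough, self-contained sketch of the kind of check the testset delegates to `test_layer`; the random-graph construction and the explicit gradient call are illustrative assumptions, not the helper's actual API:

```julia
# Build a small random graph with edge pseudo-coordinates, check the output size,
# and take a gradient through the layer's parameters.
using Flux, GraphNeuralNetworks

in_channel, out_channel, ein_channel, K = 3, 5, 10, 5
g = rand_graph(10, 60)                                   # 10 nodes, 60 directed edges
x = rand(Float32, in_channel, g.num_nodes)
e = rand(Float32, ein_channel, g.num_edges)

l = GMMConv((in_channel, ein_channel) => out_channel, K=K)
@assert size(l(g, x, e)) == (out_channel, g.num_nodes)

grads = gradient(() -> sum(l(g, x, e)), Flux.params(l))  # differentiable end to end
```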
