
Commit 8eddb5f

Merge pull request #147 from melioristic/add-GMMConv
Added GMMConv
2 parents 04d026b + 65488a1 commit 8eddb5f

File tree

3 files changed: +125 −1 lines changed


src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ export
   NNConv,
   ResGatedGraphConv,
   SAGEConv,
-
+  GMMConv,
 
   # layers/pool
   GlobalPool,

src/layers/conv.jl

Lines changed: 114 additions & 0 deletions

@@ -1065,4 +1065,118 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
     return x̄, ē
 end
 
+@doc raw"""
+    GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false)
+
+Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402).
+
+Performs the operation
+```math
+\mathbf{x}_i' = \frac{1}{|N(i)|} \sum_{j\in N(i)} \frac{1}{K} \sum_{k=1}^{K} \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
+```
+where
+```math
+w^a_{k}(e^a) = \exp\left(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k (e^a - \mu^a_k)\right)
+```
+``\Theta_k``, ``\mu^a_k``, and ``(\Sigma^{-1})^a_k`` are learnable parameters.
+
+The input to the layer is a node feature array `x` of size `(num_features, num_nodes)` and an
+edge pseudo-coordinate array `e` of size `(ein, num_edges)`.
+
+# Arguments
+
+- `in`: Number of input node features.
+- `ein`: Number of input edge features.
+- `out`: Number of output features.
+- `σ`: Activation function. Default `identity`.
+- `K`: Number of kernels. Default `1`.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `residual`: Residual connection. Default `false`.
+
+# Examples
+
+```julia
+# create data
+s = [1,1,2,3]
+t = [2,3,1,1]
+g = GNNGraph(s, t)
+nin, ein, out, K = 4, 10, 7, 8
+x = randn(Float32, nin, g.num_nodes)
+e = randn(Float32, ein, g.num_edges)
+
+# create layer
+l = GMMConv((nin, ein) => out, K=K)
+
+# forward pass
+l(g, x, e)
+```
+"""
+struct GMMConv{A<:AbstractMatrix, B, F} <: GNNLayer
+    mu::A
+    sigma_inv::A
+    bias::B
+    σ::F
+    ch::Pair{NTuple{2,Int},Int}
+    K::Int
+    dense_x::Dense
+    residual::Bool
+end
+
+@functor GMMConv
+
+function GMMConv(ch::Pair{NTuple{2,Int},Int},
+                 σ=identity;
+                 K::Int=1,
+                 bias::Bool=true,
+                 init=Flux.glorot_uniform,
+                 residual=false)
+
+    (nin, ein), out = ch
+    mu = init(ein, K)
+    sigma_inv = init(ein, K)
+    b = bias ? Flux.create_bias(mu, true, out) : false
+    dense_x = Dense(nin, out*K, bias=false)
+    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
+end
+
+function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+    (nin, ein), out = l.ch  # for notational simplicity
+
+    @assert (ein == size(e)[1] && g.num_edges == size(e)[2]) "Pseudo-coordinate dimension is not equal to (ein, num_edges)"
+
+    num_edges = g.num_edges
+    w = reshape(e, (ein, 1, num_edges))
+    mu = reshape(l.mu, (ein, l.K, 1))
+
+    w = @. -((w - mu)^2) / 2                          # -(e - μ)^2 / 2, per feature and kernel
+    w = w .* reshape(l.sigma_inv.^2, (ein, l.K, 1))   # scale by the diagonal inverse covariance
+    w = exp.(sum(w, dims=1))                          # (1, K, num_edges)
+
+    xj = reshape(l.dense_x(x), (out, l.K, :))         # (out, K, num_nodes)
 
+    m = propagate(e_mul_xj, g, mean, xj=xj, e=w)
+    m = dropdims(mean(m, dims=2), dims=2)             # (out, num_nodes)
+
+    m = l.σ.(m .+ l.bias)
+
+    if l.residual
+        if size(x, 1) == size(m, 1)
+            m += x
+        else
+            @warn "Residual not applied: output feature dimension is not equal to input feature dimension."
+        end
+    end
+
+    return m
+end
+
+(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
+
+function Base.show(io::IO, l::GMMConv)
+    (nin, ein), out = l.ch
+    print(io, "GMMConv((", nin, ",", ein, ")=>", out)
+    l.σ == identity || print(io, ", σ=", l.σ)
+    print(io, ", K=", l.K)
+    l.residual == false || print(io, ", residual=", l.residual)
+    print(io, ")")
+end
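
To make the Gaussian kernel in the docstring concrete, here is a small standalone sketch of the per-edge weight computation using plain broadcasting. The sizes (`ein`, `K`, `num_edges`) and array names are illustrative assumptions that mirror the layer's internals; the snippet does not call the layer itself.

```julia
# Illustrative sketch: Gaussian kernel weights w_k(e) = exp(-1/2 (e - μ_k)' Σ⁻¹_k (e - μ_k)),
# with a diagonal Σ⁻¹ parameterized as sigma_inv.^2. Names and sizes are assumptions.
ein, K, num_edges = 3, 2, 5          # feature dim, number of kernels, number of edges

e = randn(Float32, ein, num_edges)   # pseudo-coordinates, one column per edge
mu = randn(Float32, ein, K)          # kernel means
sigma_inv = randn(Float32, ein, K)   # per-feature inverse widths

w = reshape(e, (ein, 1, num_edges))           # (ein, 1, num_edges)
m = reshape(mu, (ein, K, 1))                  # (ein, K, 1)
w = @. -((w - m)^2) / 2                       # broadcast over kernels and edges
w = w .* reshape(sigma_inv.^2, (ein, K, 1))   # scale by diagonal inverse covariance
w = exp.(sum(w, dims=1))                      # (1, K, num_edges): one weight per kernel per edge

@assert size(w) == (1, K, num_edges)
@assert all(0 .< w .<= 1)                     # Gaussian weights lie in (0, 1]
```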

test/layers/conv.jl

Lines changed: 10 additions & 0 deletions

@@ -262,4 +262,14 @@
                        outsize=((out_channel, g.num_nodes), (out_channel, g.num_edges)))
         end
     end
+
+    @testset "GMMConv" begin
+        ein_channel = 10
+        K = 5
+        l = GMMConv((in_channel, ein_channel) => out_channel, K=K)
+        for g in test_graphs
+            g = GNNGraph(g, edata=rand(Float32, ein_channel, g.num_edges))
+            test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
+        end
+    end
 end
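
The testset above relies on suite helpers (`test_graphs`, `in_channel`, `out_channel`, `RTOL_HIGH`, `test_layer`). As a complement, the sketch below performs the same output-size check in a standalone script, assuming a project where GraphNeuralNetworks.jl and Flux are installed; the dimensions are arbitrary choices, not values from the commit.

```julia
# Standalone smoke-test sketch (assumes GraphNeuralNetworks.jl and Flux are installed);
# mirrors the `outsize` check made by the testset above, without the suite helpers.
using GraphNeuralNetworks, Flux

s, t = [1, 1, 2, 3], [2, 3, 1, 1]
g = GNNGraph(s, t, edata=rand(Float32, 10, 4))   # 4 edges, 10-dimensional pseudo-coordinates

in_channel, ein_channel, out_channel, K = 4, 10, 7, 5
l = GMMConv((in_channel, ein_channel) => out_channel, K=K)

x = rand(Float32, in_channel, g.num_nodes)
y = l(g, x, edge_features(g))

@assert size(y) == (out_channel, g.num_nodes)
```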
