Skip to content

Commit 2e6a392

Browse files
Update conv.jl
learnable param in doc, num_edge in the last dim (remove permutedims)
1 parent 62227b3 commit 2e6a392

File tree

1 file changed

+24
-34
lines changed

1 file changed

+24
-34
lines changed

src/layers/conv.jl

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,53 +1066,45 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
10661066
end
10671067

10681068
@doc raw"""
1069-
GMMConv(in => out, n_kernel, u_dim, σ=identity; [init, bias])
1070-
1069+
GMMConv(in => out, n_kernel, e_dim, σ=identity; [init, bias])
10711070
Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402)
1072-
10731071
Performs the operation
10741072
```math
10751073
\mathbf{x}_i' = \frac{1}{|N(i)|} \sum_{j\in N(i)}\frac{1}{K}\sum_{k=1}^{K} \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
10761074
```
1077-
10781075
where
10791076
```math
10801077
w^a_{k}(e^a) = \exp\left(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^a_k)^{-1}(e^a - \mu^a_k)\right)
10811078
```
1079+
$\Theta_k$, $\mu^a_k$, $(\Sigma^a_k)^{-1}$ are learnable parameters.
10821080
10831081
The input to the layer is a node feature array `X` of size `(num_features, num_nodes)` and
10841082
edge pseudo-coordinate array `U` of size `(num_features, num_edges)`
1085-
10861083
# Arguments
1087-
10881084
- `in`: Number of input features.
10891085
- `out`: Number of output features.
10901086
- `n_kernel` : Number of kernels.
1091-
- `u_dim` : Dimensionality of pseudo coordinates.
1087+
- `e_dim` : Dimensionality of pseudo coordinates.
10921088
- `σ`: Activation function. Default `identity`.
10931089
- `bias`: Add learnable bias. Default `true`.
10941090
- `init`: Weights' initializer. Default `glorot_uniform`.
10951091
10961092
# Examples
10971093
10981094
```julia
1099-
11001095
# create data
11011096
s = [1,1,2,3]
11021097
t = [2,3,1,1]
11031098
g = GNNGraph(s,t)
1104-
1105-
in_feature, out_feature, n_k, u_dim = 4, 7, 8, 10
1106-
1099+
in_feature, out_feature, n_k, e_dim = 4, 7, 8, 10
11071100
x = randn(in_feature, g.num_nodes)
1108-
u = randn(u_dim, g.num_edges)
1101+
u = randn(e_dim, g.num_edges)
11091102
11101103
# create layer
1111-
l = GMMConv(in_feature=>out_feature, n_k, u_dim)
1104+
l = GMMConv(in_feature=>out_feature, n_k, e_dim)
11121105
11131106
# forward pass
11141107
l(g, x, u)
1115-
11161108
```
11171109
"""
11181110

@@ -1123,55 +1115,53 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
11231115
σ::F
11241116
ch::Pair{Int, Int}
11251117
n_kernel::Int
1126-
u_dim::Int
1118+
e_dim::Int
11271119
dense_x::Dense
11281120
end
11291121

11301122
Flux.@functor GMMConv
11311123

11321124
function GMMConv(ch::Pair{Int, Int},
11331125
n_kernel::Int,
1134-
u_dim::Int,
1126+
e_dim::Int,
11351127
σ=identity;
11361128
init=Flux.glorot_uniform,
11371129
bias::Bool=true)
11381130
in, out = ch
1139-
mu = init(n_kernel, u_dim)
1140-
sigma_inv = init(n_kernel, u_dim)
1131+
mu = init(n_kernel, e_dim)
1132+
sigma_inv = init(n_kernel, e_dim)
11411133
b = bias ? Flux.create_bias(ones(out), true) : false
11421134
dense_x = Dense(in, out*n_kernel, bias=false)
1143-
GMMConv(mu, sigma_inv, b, σ, ch, n_kernel, u_dim, dense_x)
1135+
GMMConv(mu, sigma_inv, b, σ, ch, n_kernel, e_dim, dense_x)
11441136
end
11451137

11461138
function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
11471139

11481140

1149-
@assert (l.u_dim == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (u_dim=$(u_dim),num_edge=$(g.num_edges))"
1141+
@assert (l.e_dim == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (e_dim=$(e_dim),num_edge=$(g.num_edges))"
11501142

11511143
num_edges = g.num_edges
11521144
d = degree(g, dir=:in)
1153-
u = reshape(u, (num_edges, 1, l.u_dim))
1154-
mu = reshape(l.mu, (1, l.n_kernel, l.u_dim))
1155-
1156-
w = -0.5*(u.-mu).^2
1157-
w = w .* ((reshape(l.sigma_inv, (1, l.n_kernel, l.u_dim)).^2) )
1158-
w = exp.(sum(w, dims = 3 )) # n_edges, n_kernel, 1
1159-
w = permutedims(w, [3,2,1])
1160-
1161-
xj = reshape(l.dense_x(x), (l.ch[2],l.n_kernel,:))
1145+
u = reshape(u, (l.e_dim, 1, num_edges))
1146+
mu = reshape(l.mu, (l.e_dim, l.n_kernel, 1))
11621147

1163-
x = propagate(e_mul_xj, g, +, xj=xj, e=w)
1164-
x = dropdims(mean(x, dims=2), dims=2)
1165-
x = 1 / d .* x
1148+
e = -0.5*(u.-mu).^2
1149+
e = e .* ((reshape(l.sigma_inv, (l.e_dim, l.n_kernel, 1)).^2) )
1150+
e = exp.(sum(e, dims = 1 )) # (1, n_kernel, num_edge)
1151+
1152+
xj = reshape(l.dense_x(x), (l.ch[2],l.n_kernel,:)) # (out, n_kernel, num_nodes)
1153+
x = propagate(e_mul_xj, g, +, xj=xj, e=e)
1154+
x = dropdims(mean(x, dims=2), dims=2) # (out, num_nodes)
1155+
x = 1 / d .* x
11661156

11671157
return l.σ(x .+ l.bias)
11681158
end
11691159

11701160
function Base.show(io::IO, l::GMMConv)
1171-
in, out, n_kernel, u_dim = l.ch[1], l.ch[2], l.n_kernel, l.u_dim
1161+
in, out, n_kernel, e_dim = l.ch[1], l.ch[2], l.n_kernel, l.e_dim
11721162
print(io, "GMMConv(", in, " => ", out)
11731163
print(io, ", n_kernel= ", n_kernel)
1174-
print(io, ", pseudo-cordinate dimension = ", u_dim)
1164+
print(io, ", pseudo-cordinate dimension = ", e_dim)
11751165
print(io, ")")
11761166

11771167
end

0 commit comments

Comments
 (0)