Skip to content

Commit 4a2c609

Browse files
committed
n_kernel is now k, default added, test added
1 parent 40871c4 commit 4a2c609

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

src/layers/conv.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
10661066
end
10671067

10681068
@doc raw"""
1069-
GMMConv(in => out, n_kernel, e_dim, σ=identity; [init, bias])
1069+
    GMMConv(in => out, σ=identity; K=1, e_dim=1, [init, bias])
10701070
Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402)
10711071
Performs the operation
10721072
```math
@@ -1083,7 +1083,7 @@ edge pseudo-cordinate array 'U' of size `(num_features, num_edges)`
10831083
# Arguments
10841084
- `in`: Number of input features.
10851085
- `out`: Number of output features.
1086-
- `n_kernel` : Number of kernels.
1086+
- `K` : Number of kernels.
10871087
- `e_dim` : Dimensionality of pseudo coordinates.
10881088
- `σ`: Activation function. Default `identity`.
10891089
- `bias`: Add learnable bias. Default `true`.
@@ -1096,12 +1096,12 @@ edge pseudo-cordinate array 'U' of size `(num_features, num_edges)`
10961096
s = [1,1,2,3]
10971097
t = [2,3,1,1]
10981098
g = GNNGraph(s,t)
1099-
in_feature, out_feature, n_k, e_dim = 4, 7, 8, 10
1099+
in_feature, out_feature, K, e_dim = 4, 7, 8, 10
11001100
x = randn(in_feature, g.num_nodes)
11011101
e = randn(e_dim, g.num_edges)
11021102
11031103
# create layer
1104-
l = GMMConv(in_feature=>out_feature, n_k, e_dim)
1104+
l = GMMConv(in_feature=>out_feature, K=K, e_dim=e_dim)
11051105
11061106
# forward pass
11071107
l(g, x, e)
@@ -1114,42 +1114,41 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
11141114
bias::B
11151115
σ::F
11161116
ch::Pair{Int, Int}
1117-
n_kernel::Int
1117+
K::Int
11181118
e_dim::Int
11191119
dense_x::Dense
11201120
end
11211121

1122-
Flux.@functor GMMConv
1122+
@functor GMMConv

# Keyword constructor for `GMMConv`.
#
# Arguments:
# - `ch`   : `in => out` feature-dimension pair.
# - `σ`    : activation function applied by the layer. Default `identity`.
# Keywords:
# - `K`    : number of Gaussian kernels. Default `1`.
# - `e_dim`: dimensionality of the edge pseudo-coordinates. Default `1`.
# - `init` : weight initializer for the kernel means and inverse widths.
# - `bias` : whether to add a learnable bias. Default `true`.
#
# Throws `ArgumentError` if `K` or `e_dim` is not positive (either would
# produce empty parameter arrays and a silently broken layer).
function GMMConv(ch::Pair{Int, Int},
                 σ=identity;
                 K::Int=1,
                 e_dim::Int=1,
                 init=Flux.glorot_uniform,
                 bias::Bool=true)
    K >= 1 || throw(ArgumentError("number of kernels K must be positive, got $K"))
    e_dim >= 1 || throw(ArgumentError("pseudo-coordinate dimension e_dim must be positive, got $e_dim"))
    in, out = ch
    # One (mean, inverse-width) row per kernel, each of length `e_dim`.
    mu = init(K, e_dim)
    sigma_inv = init(K, e_dim)
    # NOTE(review): bias is initialized to `ones(out)` (Float64); Flux layers
    # conventionally use zero-initialized Float32 bias — confirm intentional.
    b = bias ? Flux.create_bias(ones(out), true) : false
    # Projects node features to K stacked `out`-sized blocks, one per kernel.
    dense_x = Dense(in, out*K, bias=false)
    GMMConv(mu, sigma_inv, b, σ, ch, K, e_dim, dense_x)
end
11371137

11381138
function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
11391139

1140-
11411140
@assert (l.e_dim == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (e_dim=$(e_dim),num_edge=$(g.num_edges))"
11421141

11431142
num_edges = g.num_edges
11441143
d = degree(g, dir=:in)
11451144
u = reshape(u, (l.e_dim, 1, num_edges))
1146-
mu = reshape(l.mu, (l.e_dim, l.n_kernel, 1))
1145+
mu = reshape(l.mu, (l.e_dim, l.K, 1))
11471146

11481147
e = -0.5*(u.-mu).^2
1149-
e = e .* ((reshape(l.sigma_inv, (l.e_dim, l.n_kernel, 1)).^2) )
1150-
e = exp.(sum(e, dims = 1 )) # (1, n_kernel, num_edge)
1148+
e = e .* ((reshape(l.sigma_inv, (l.e_dim, l.K, 1)).^2) )
1149+
e = exp.(sum(e, dims = 1 )) # (1, K, num_edge)
11511150

1152-
xj = reshape(l.dense_x(x), (l.ch[2],l.n_kernel,:)) # (out, n_kernel, num_nodes)
1151+
xj = reshape(l.dense_x(x), (l.ch[2],l.K,:)) # (out, K, num_nodes)
11531152
x = propagate(e_mul_xj, g, +, xj=xj, e=e)
11541153
x = dropdims(mean(x, dims=2), dims=2) # (out, num_nodes)
11551154
x = 1 / d .* x
@@ -1158,10 +1157,10 @@ function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
11581157
end
11591158

11601159
# Compact one-line display, e.g. `GMMConv(4 => 7, K=8, e_dim=10)`.
function Base.show(io::IO, l::GMMConv)
    print(io, "GMMConv(", l.ch[1], " => ", l.ch[2])
    print(io, ", K=", l.K)
    print(io, ", e_dim=", l.e_dim)
    print(io, ")")
end

test/layers/conv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,4 +262,14 @@
262262
outsize=((out_channel, g.num_nodes), (out_channel, g.num_edges)))
263263
end
264264
end
265+
266+
@testset "GMMConv" begin
267+
e_dim = 10
268+
K = 5
269+
l = GMMConv(in_channel => out_channel, K=K, e_dim=e_dim)
270+
for g in test_graphs
271+
g = GNNGraph(g, edata=rand(e_dim, g.num_edges))
272+
test_layer(l, g, rtol=RTOL_HIGH, outsize = (out_channel, g.num_edges))
273+
end
274+
end
265275
end

0 commit comments

Comments
 (0)