
Commit 2863127

Merge branch 'add-GMMConv' of https://github.com/melioristic/GraphNeuralNetworks.jl into add-GMMConv
1 parent: 5e589b5

File tree: 1 file changed, +8 -6 lines changed


src/layers/conv.jl: 8 additions & 6 deletions
@@ -1142,29 +1142,31 @@ end
 function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
     (nin, ein), out = l.ch #Notational Simplicity

-    @assert (ein == size(e)[1] && g.num_edges == size(e)[2]) "Pseudo-cordinate dim $(size(e)) does not match (ein=$(ein),num_edge=$(g.num_edges))"
+    @assert (ein == size(e)[1] && g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)"

     num_edges = g.num_edges
     d = degree(g, dir=:in)
+    print(eltype(d))
     w = reshape(e, (ein, 1, num_edges))
     mu = reshape(l.mu, (ein, l.K, 1))

-    w = @. -0.5 * (w - mu)^2
+    w = @. ((w - mu)^2) / 2
     w = w .* reshape(l.sigma_inv, (ein, l.K, 1))
     w = exp.(sum(w, dims = 1)) # (1, K, num_edge)

     xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes)
+
     m = propagate(e_mul_xj, g, +, xj=xj, e=w)
     m = dropdims(mean(m, dims=2), dims=2) # (out, num_nodes)
-    m = 1 / d .* m
-
+    m = m ./ reshape(d, (1, g.num_nodes))
+
     m = l.σ(m .+ l.bias)

     if l.residual
         if size(x, 1) == size(m, 1)
             m += x
         else
-            @warn "Residual not applied : output feature $(size(m,1)) !== input_feature $(size(x,1))"
+            @warn "Residual not applied : output feature is not equal to input_feature"
         end
     end

@@ -1180,4 +1182,4 @@ function Base.show(io::IO, l::GMMConv)
     print(io, ", σ=", l.σ)
     print(io, ")")

-end
+end
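
For reference, the forward pass above implements the Gaussian-mixture edge weighting of MoNet (Monti et al.): each edge's pseudo-coordinates e are scored against K learned kernels, w_k(e) = exp(-1/2 (e - mu_k)' Sigma_k^{-1} (e - mu_k)), with the diagonal of Sigma_k^{-1} stored in l.sigma_inv. A minimal usage sketch follows; the constructor signature (the (nin, ein) => out pair unpacked from l.ch, and the keywords K and residual) is an assumption inferred from the fields the forward pass accesses, not confirmed by this diff:

using GraphNeuralNetworks, Flux

# Hypothetical usage sketch; GMMConv's constructor keywords (K, residual)
# are assumptions inferred from the fields l.K and l.residual above.
nin, ein, out = 4, 2, 8
g = rand_graph(10, 30)                 # random graph: 10 nodes, 30 edges
x = rand(Float32, nin, g.num_nodes)    # node features       (nin, num_nodes)
e = rand(Float32, ein, g.num_edges)    # pseudo-coordinates  (ein, num_edges)

l = GMMConv((nin, ein) => out, relu; K = 3)
y = l(g, x, e)                         # node embeddings     (out, num_nodes)
@assert size(y) == (out, g.num_nodes)

The edge weights w broadcast against the K kernel responses per node, so the mean over dims=2 in the forward pass averages the K kernel-weighted messages into a single (out, num_nodes) output.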
