
Commit 4928894

e_dim to ein
1 parent 8541c5d

File tree

2 files changed: +15 -16 lines changed


src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ export
   NNConv,
   ResGatedGraphConv,
   SAGEConv,
-
+  GMMConv,
 
   # layers/pool
   GlobalPool,

src/layers/conv.jl

Lines changed: 14 additions & 15 deletions
@@ -1112,39 +1112,38 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
     sigma_inv::A
     bias::B
     σ::F
-    ch::Pair{Int, Int}
+    ch::Pair{NTuple{2,Int}, Int}
     K::Int
-    e_dim::Int
     dense_x::Dense
 end
 
 @functor GMMConv
 
-function GMMConv(ch::Pair{Int, Int},
+function GMMConv(ch::Pair{NTuple{2,Int}, Int},
                  σ=identity;
                  K::Int=1,
-                 e_dim::Int=1,
                  init=Flux.glorot_uniform,
                  bias::Bool=true)
-    in, out = ch
-    mu = init(K, e_dim)
-    sigma_inv = init(K, e_dim)
+    (nin, ein), out = ch
+    mu = init(ein, K)
+    sigma_inv = init(ein, K)
     b = bias ? Flux.create_bias(ones(out), true) : false
-    dense_x = Dense(in, out*K, bias=false)
-    GMMConv(mu, sigma_inv, b, σ, ch, K, e_dim, dense_x)
+    dense_x = Dense(nin, out*K, bias=false)
+    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x)
 end
 
 function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
 
-    @assert (l.e_dim == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (e_dim=$(e_dim),num_edge=$(g.num_edges))"
+    ein = l.ch[1][2]
+    @assert (ein == size(u, 1) && g.num_edges == size(u, 2)) "Pseudo-coordinate dim $(size(u)) does not match (ein=$(ein), num_edges=$(g.num_edges))"
 
     num_edges = g.num_edges
     d = degree(g, dir=:in)
-    u = reshape(u, (l.e_dim, 1, num_edges))
-    mu = reshape(l.mu, (l.e_dim, l.K, 1))
+    u = reshape(u, (ein, 1, num_edges))
+    mu = reshape(l.mu, (ein, l.K, 1))
 
     e = -0.5*(u.-mu).^2
-    e = e .* ((reshape(l.sigma_inv, (l.e_dim, l.K, 1)).^2) )
+    e = e .* (reshape(l.sigma_inv, (ein, l.K, 1)).^2)
     e = exp.(sum(e, dims = 1)) # (1, K, num_edges)
 
     xj = reshape(l.dense_x(x), (l.ch[2], l.K, :)) # (out, K, num_nodes)
@@ -1156,10 +1155,10 @@ function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
 end
 
 function Base.show(io::IO, l::GMMConv)
-    in, out, K, e_dim = l.ch[1], l.ch[2], l.K, l.e_dim
+    (in, ein), out, K = l.ch[1], l.ch[2], l.K
     print(io, "GMMConv(", in, " => ", out)
     print(io, ", K=", K)
-    print(io, ", e_dim=", e_dim)
+    print(io, ", ein=", ein)
     print(io, ")")
 
 end
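
In the forward pass above, each edge's pseudo-coordinate u is scored against the K learned Gaussian kernels, w_k(u) = exp(-1/2 * Σ_i sigma_inv[i,k]^2 * (u[i] - mu[i,k])^2), and those weights mix the K linear projections of the node features. A minimal usage sketch of the renamed constructor follows; the graph, dimensions, and activation are illustrative assumptions, not part of the commit:

using Flux, GraphNeuralNetworks

# Hypothetical sizes, for illustration only.
nin, ein, out = 4, 2, 8                 # node feature dim, pseudo-coordinate dim, output dim
g = rand_graph(10, 30)                  # random GNNGraph with 10 nodes and 30 edges
x = rand(Float32, nin, g.num_nodes)     # node features, size (nin, num_nodes)
u = rand(Float32, ein, g.num_edges)     # edge pseudo-coordinates, size (ein, num_edges)

# ein now travels inside ch as (nin, ein) => out; the e_dim keyword is gone.
l = GMMConv((nin, ein) => out, tanh; K = 3)
y = l(g, x, u)                          # output node features, size (out, num_nodes)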
