@@ -1142,29 +1142,31 @@ end
1142
1142
function (l:: GMMConv )(g:: GNNGraph , x:: AbstractMatrix , e:: AbstractMatrix )
1143
1143
(nin, ein), out = l. ch # Notational Simplicity
1144
1144
1145
- @assert (ein == size (e)[1 ] && g. num_edges == size (e)[2 ]) " Pseudo-cordinate dim $( size (e)) does not match (ein= $(ein) ,num_edge= $(g . num_edges) )"
1145
+ @assert (ein == size (e)[1 ] && g. num_edges == size (e)[2 ]) " Pseudo-cordinate dimension is not equal to (ein,num_edge)"
1146
1146
1147
1147
num_edges = g. num_edges
1148
1148
d = degree (g, dir= :in )
1149
+ print (eltype (d))
1149
1150
w = reshape (e, (ein, 1 , num_edges))
1150
1151
mu = reshape (l. mu, (ein, l. K, 1 ))
1151
1152
1152
- w = @. - 0.5 * ( w - mu)^ 2
1153
+ w = @. (( w - mu)^ 2 ) / 2
1153
1154
w = w .* reshape (l. sigma_inv, (ein, l. K, 1 ))
1154
1155
w = exp .(sum (w, dims = 1 )) # (1, K, num_edge)
1155
1156
1156
1157
xj = reshape (l. dense_x (x), (out, l. K, :)) # (out, K, num_nodes)
1158
+
1157
1159
m = propagate (e_mul_xj, g, + , xj= xj, e= w)
1158
1160
m = dropdims (mean (m, dims= 2 ), dims= 2 ) # (out, num_nodes)
1159
- m = 1 / d .* m
1160
-
1161
+ m = m ./ reshape (d, ( 1 , g . num_nodes))
1162
+
1161
1163
m = l. σ (m .+ l. bias)
1162
1164
1163
1165
if l. residual
1164
1166
if size (x, 1 ) == size (m, 1 )
1165
1167
m += x
1166
1168
else
1167
- @warn " Residual not applied : output feature $( size (m, 1 )) !== input_feature $( size (x, 1 )) "
1169
+ @warn " Residual not applied : output feature is not equal to input_feature "
1168
1170
end
1169
1171
end
1170
1172
@@ -1180,4 +1182,4 @@ function Base.show(io::IO, l::GMMConv)
1180
1182
print (io, " , σ=" , l. σ)
1181
1183
print (io, " )" )
1182
1184
1183
- end
1185
+ end
0 commit comments