@@ -1066,53 +1066,45 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
1066
1066
end
1067
1067
1068
1068
@doc raw """
1069
- GMMConv(in => out, n_kernel, u_dim, σ=identity; [init, bias])
1070
-
1069
+ GMMConv(in => out, n_kernel, e_dim, σ=identity; [init, bias])
1071
1070
Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402)
1072
-
1073
1071
Performs the operation
1074
1072
```math
1075
1073
\m athbf{x}_i' = \f rac{1}{|N(i)|} \s um_{j\i n N(i)}\f rac{1}{K}\s um_{k=1}^k \m athbf{w}_k(\m athbf{e}_{j\t o i}) \o dot \T heta_k \m athbf{x}_j
1076
1074
```
1077
-
1078
1075
where
1079
1076
```math
1080
1077
w^a_{k}(e^a) = \e xp(\f rac{-1}{2}(e^a - \m u^a_k)^T \S igma^a_k^{-1}(e^a - \m u^a_k))
1081
1078
```
1079
+ $\T heta_k$, $\m u^a_k$, $\S igma^a_k^{-1}$ are learnable parameters.
1082
1080
1083
1081
The input to the layer is a node feature array 'X' of size `(num_features, num_nodes)` and
1084
1082
edge pseudo-cordinate array 'U' of size `(num_features, num_edges)`
1085
-
1086
1083
# Arguments
1087
-
1088
1084
- `in`: Number of input features.
1089
1085
- `out`: Number of output features.
1090
1086
- `n_kernel` : Number of kernels.
1091
- - `u_dim ` : Dimensionality of pseudo coordinates.
1087
+ - `e_dim ` : Dimensionality of pseudo coordinates.
1092
1088
- `σ`: Activation function. Default `identity`.
1093
1089
- `bias`: Add learnable bias. Default `true`.
1094
1090
- `init`: Weights' initializer. Default `glorot_uniform`.
1095
1091
1096
1092
#Examples
1097
1093
1098
1094
```julia
1099
-
1100
1095
# create data
1101
1096
s = [1,1,2,3]
1102
1097
t = [2,3,1,1]
1103
1098
g = GNNGraph(s,t)
1104
-
1105
- in_feature, out_feature, n_k, u_dim = 4, 7, 8, 10
1106
-
1099
+ in_feature, out_feature, n_k, e_dim = 4, 7, 8, 10
1107
1100
x = randn(in_feature, g.num_nodes)
1108
- u = randn(u_dim , g.num_edges)
1101
+ u = randn(e_dim , g.num_edges)
1109
1102
1110
1103
# create layer
1111
- l = GMMConv(in_feature=>out_feature, n_k, u_dim )
1104
+ l = GMMConv(in_feature=>out_feature, n_k, e_dim )
1112
1105
1113
1106
# forward pass
1114
1107
l(g, x, u)
1115
-
1116
1108
```
1117
1109
"""
1118
1110
@@ -1123,55 +1115,53 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
1123
1115
σ:: F
1124
1116
ch:: Pair{Int, Int}
1125
1117
n_kernel:: Int
1126
- u_dim :: Int
1118
+ e_dim :: Int
1127
1119
dense_x:: Dense
1128
1120
end
1129
1121
1130
1122
Flux. @functor GMMConv
1131
1123
1132
1124
function GMMConv (ch:: Pair{Int, Int} ,
1133
1125
n_kernel:: Int ,
1134
- u_dim :: Int ,
1126
+ e_dim :: Int ,
1135
1127
σ= identity;
1136
1128
init= Flux. glorot_uniform,
1137
1129
bias:: Bool = true )
1138
1130
in, out = ch
1139
- mu = init (n_kernel, u_dim )
1140
- sigma_inv = init (n_kernel, u_dim )
1131
+ mu = init (n_kernel, e_dim )
1132
+ sigma_inv = init (n_kernel, e_dim )
1141
1133
b = bias ? Flux. create_bias (ones (out), true ) : false
1142
1134
dense_x = Dense (in, out* n_kernel, bias= false )
1143
- GMMConv (mu, sigma_inv, b, σ, ch, n_kernel, u_dim , dense_x)
1135
+ GMMConv (mu, sigma_inv, b, σ, ch, n_kernel, e_dim , dense_x)
1144
1136
end
1145
1137
1146
1138
function (l:: GMMConv )(g:: GNNGraph , x:: AbstractMatrix , u:: AbstractMatrix )
1147
1139
1148
1140
1149
- @assert (l. u_dim == size (u)[1 ] && g. num_edges == size (u)[2 ]) " Pseudo-cordinate dim $(size (u)) does not match (u_dim =$(u_dim ) ,num_edge=$(g. num_edges) )"
1141
+ @assert (l. e_dim == size (u)[1 ] && g. num_edges == size (u)[2 ]) " Pseudo-cordinate dim $(size (u)) does not match (e_dim =$(e_dim ) ,num_edge=$(g. num_edges) )"
1150
1142
1151
1143
num_edges = g. num_edges
1152
1144
d = degree (g, dir= :in )
1153
- u = reshape (u, (num_edges, 1 , l. u_dim))
1154
- mu = reshape (l. mu, (1 , l. n_kernel, l. u_dim))
1155
-
1156
- w = - 0.5 * (u.- mu). ^ 2
1157
- w = w .* ((reshape (l. sigma_inv, (1 , l. n_kernel, l. u_dim)).^ 2 ) )
1158
- w = exp .(sum (w, dims = 3 )) # n_edges, n_kernel, 1
1159
- w = permutedims (w, [3 ,2 ,1 ])
1160
-
1161
- xj = reshape (l. dense_x (x), (l. ch[2 ],l. n_kernel,:))
1145
+ u = reshape (u, (l. e_dim, 1 , num_edges))
1146
+ mu = reshape (l. mu, (l. e_dim, l. n_kernel, 1 ))
1162
1147
1163
- x = propagate (e_mul_xj, g, + , xj= xj, e= w)
1164
- x = dropdims (mean (x, dims= 2 ), dims= 2 )
1165
- x = 1 / d .* x
1148
+ e = - 0.5 * (u.- mu). ^ 2
1149
+ e = e .* ((reshape (l. sigma_inv, (l. e_dim, l. n_kernel, 1 )).^ 2 ) )
1150
+ e = exp .(sum (e, dims = 1 )) # (1, n_kernel, num_edge)
1151
+
1152
+ xj = reshape (l. dense_x (x), (l. ch[2 ],l. n_kernel,:)) # (out, n_kernel, num_nodes)
1153
+ x = propagate (e_mul_xj, g, + , xj= xj, e= e)
1154
+ x = dropdims (mean (x, dims= 2 ), dims= 2 ) # (out, num_nodes)
1155
+ x = 1 / d .* x
1166
1156
1167
1157
return l. σ (x .+ l. bias)
1168
1158
end
1169
1159
1170
1160
function Base. show (io:: IO , l:: GMMConv )
1171
- in, out, n_kernel, u_dim = l. ch[1 ], l. ch[2 ], l. n_kernel, l. u_dim
1161
+ in, out, n_kernel, e_dim = l. ch[1 ], l. ch[2 ], l. n_kernel, l. e_dim
1172
1162
print (io, " GMMConv(" , in, " => " , out)
1173
1163
print (io, " , n_kernel= " , n_kernel)
1174
- print (io, " , pseudo-cordinate dimension = " , u_dim )
1164
+ print (io, " , pseudo-cordinate dimension = " , e_dim )
1175
1165
print (io, " )" )
1176
1166
1177
1167
end
0 commit comments