@@ -1066,28 +1066,31 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
end

@doc raw"""
-    GMMConv(in => out, K, e_dim, σ=identity; [init, bias])
+    GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false)

Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402).

Performs the operation
```math
\mathbf{x}_i' = \frac{1}{|N(i)|} \sum_{j\in N(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
```
where
```math
- w^a_{k}(e^a) = \exp(\frac{-1}{2}(e^a - \mu^a_k)^T \Sigma^a_k^{-1}(e^a - \mu^a_k))
+ w^a_{k}(e^a) = \exp(\frac{-1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k (e^a - \mu^a_k))
```
- $\Theta_k$, $\mu^a_k$, $\Sigma^a_k^{-1}$ are learnable parameters.
+ $\Theta_k$, $\mu^a_k$, $(\Sigma^{-1})^a_k$ are learnable parameters.

The input to the layer is a node feature array `X` of size `(num_features, num_nodes)` and an
edge pseudo-coordinate array `U` of size `(num_features, num_edges)`.
+
# Arguments

- - `in`: Number of input features.
+
+ - `in`: Number of input node features.
+ - `ein`: Number of input edge features.
- `out`: Number of output features.
- - `K`: Number of kernels. Default `1`.
- - `e_dim`: Dimensionality of pseudo coordinates. Must match the edge feature dimension in the forward pass.
- `σ`: Activation function. Default `identity`.
+ - `K`: Number of kernels. Default `1`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
+ - `residual`: Residual connection. Default `false`.

# Examples

@@ -1096,12 +1099,12 @@ edge pseudo-coordinate array `U` of size `(num_features, num_edges)`
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s,t)
- in_feature, out_feature, K, e_dim = 4, 7, 8, 10
- x = randn(in_feature, g.num_nodes)
- e = randn(e_dim, g.num_edges)
+ nin, ein, out, K = 4, 10, 7, 8
+ x = randn(nin, g.num_nodes)
+ e = randn(ein, g.num_edges)

# create layer
- l = GMMConv(in_feature => out_feature, K, e_dim)
+ l = GMMConv((nin, ein) => out, K=K)

# forward pass
l(g, x, e)
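
A reviewer-side note, not part of the patch: the kernel weight computation that the docstring formula and the forward pass implement can be sketched standalone. The shapes and data below are hypothetical, and the inverse covariances are taken as diagonal, as in the layer:

```julia
ein, K = 10, 8
mu = randn(ein, K)          # μ^a_k: one mean vector per kernel
sigma_inv = randn(ein, K)   # (Σ^{-1})^a_k: diagonal inverse covariances
e = randn(ein)              # pseudo-coordinates of a single edge

# w[k] = exp(-1/2 * Σ_a (Σ^{-1})^a_k * (e^a - μ^a_k)^2), one weight per kernel
w = vec(exp.(sum(-0.5 .* sigma_inv .* (e .- mu).^2, dims=1)))
```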
@@ -1112,52 +1115,66 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
    sigma_inv::A
    bias::B
    σ::F
-   ch::Pair{NTuple{2,Int}, Int}
+   ch::Pair{NTuple{2,Int},Int}
    K::Int
    dense_x::Dense
+   residual::Bool
end

@functor GMMConv

- function GMMConv(ch::Pair{NTuple{2,Int}, Int},
+ function GMMConv(ch::Pair{NTuple{2,Int},Int},
                  σ=identity;
                  K::Int=1,
+                 bias::Bool=true,
                  init=Flux.glorot_uniform,
-                 bias::Bool=true)
+                 residual=false)
+
    (nin, ein), out = ch
    mu = init(ein, K)
    sigma_inv = init(K, ein)
    b = bias ? Flux.create_bias(ones(out), true) : false
-   dense_x = Dense(in, out*K, bias=false)
-   GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x)
+   dense_x = Dense(nin, out*K, bias=false)
+   GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
end

- function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
+ function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+   (nin, ein), out = l.ch   # for notational simplicity
-   @assert (l.ein == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-cordinate dim $(size(u)) does not match (ein=$(ein), num_edge=$(g.num_edges))"
+   @assert (ein == size(e)[1] && g.num_edges == size(e)[2]) "Pseudo-coordinate dim $(size(e)) does not match (ein=$(ein), num_edges=$(g.num_edges))"
    num_edges = g.num_edges
    d = degree(g, dir=:in)
-   u = reshape(u, (l.ein, 1, num_edges))
-   mu = reshape(l.mu, (l.ein, l.K, 1))
+   w = reshape(e, (ein, 1, num_edges))
+   mu = reshape(l.mu, (ein, l.K, 1))

-   e = -0.5 * (u .- mu).^2
-   e = e .* (reshape(l.sigma_inv, (l.ein, l.K, 1)).^2)
-   e = exp.(sum(e, dims=1))   # (1, K, num_edges)
-
-   xj = reshape(l.dense_x(x), (l.ch[2], l.K, :))   # (out, K, num_nodes)
-   x = propagate(e_mul_xj, g, +, xj=xj, e=e)
-   x = dropdims(mean(x, dims=2), dims=2)   # (out, num_nodes)
-   x = 1 / d .* x
+   w = -0.5 * (w .- mu).^2
+   w = w .* reshape(l.sigma_inv, (ein, l.K, 1))
+   w = exp.(sum(w, dims=1))   # (1, K, num_edges)
+
+   xj = reshape(l.dense_x(x), (out, l.K, :))   # (out, K, num_nodes)
+   m = propagate(e_mul_xj, g, +, xj=xj, e=w)
+   m = dropdims(mean(m, dims=2), dims=2)   # (out, num_nodes)
+   m = m ./ d'   # divide each node's column by its in-degree: the 1/|N(i)| normalization
+
+   m = l.σ.(m .+ l.bias)
+
+   if l.residual
+       if size(x, 1) == size(m, 1)
+           m += x
+       else
+           @warn "Residual not applied: output feature size $(size(m, 1)) != input feature size $(size(x, 1))"
+       end
+   end

-   return l.σ(x .+ l.bias)
+   return m
end

function Base.show(io::IO, l::GMMConv)
-   in, out, K, ein = l.ch[1], l.ch[2], l.K, l.ein
-   print(io, "GMMConv(", in, " => ", out)
-   print(io, ", K=", K)
-   print(io, ", ein=", ein)
+   (nin, ein), out = l.ch
+   print(io, "GMMConv((", nin, ", ", ein, ") => ", out)
+   print(io, ", K=", l.K)
+   print(io, ", σ=", l.σ)
    print(io, ")")
end
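
As a usage sketch of the new keyword interface and the residual path (graph and sizes are hypothetical; `nin == out` so the skip connection applies):

```julia
using GraphNeuralNetworks

s, t = [1, 1, 2, 3], [2, 3, 1, 1]
g = GNNGraph(s, t)
nin, ein, out = 4, 10, 4               # nin == out, so the residual branch adds x back

x = randn(Float32, nin, g.num_nodes)
e = randn(Float32, ein, g.num_edges)

l = GMMConv((nin, ein) => out; K=2, residual=true)
y = l(g, x, e)                         # (out, num_nodes)
```

When `out != nin`, the layer warns and skips the residual instead of erroring.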