|
1 | 1 | export
|
2 | 2 | OperatorConv,
|
3 | 3 | SpectralConv,
|
4 |
| - OperatorKernel |
| 4 | + OperatorKernel, |
| 5 | + GraphKernel |
5 | 6 |
|
6 | 7 | struct OperatorConv{P, T, S, TT}
|
7 | 8 | weight::T
|
@@ -170,6 +171,52 @@ function (m::OperatorKernel)(𝐱)
|
170 | 171 | return m.σ.(m.linear(𝐱) + m.conv(𝐱))
|
171 | 172 | end
|
172 | 173 |
|
| 174 | +""" |
| 175 | + GraphKernel(κ, ch, σ=identity) |
| 176 | +
|
| 177 | +Graph kernel layer. |
| 178 | +
|
| 179 | +## Arguments |
| 180 | +
|
| 181 | +* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP. |
| 182 | +* `ch`: Channel size for linear transform, e.g. `32`. |
| 183 | +* `σ`: Activation function. |
| 184 | +""" |
| 185 | +struct GraphKernel{A,B,F} <: MessagePassing |
| 186 | + linear::A |
| 187 | + κ::B |
| 188 | + σ::F |
| 189 | +end |
| 190 | + |
| 191 | +function GraphKernel(κ, ch::Int, σ=identity; init=Flux.glorot_uniform) |
| 192 | + W = init(ch, ch) |
| 193 | + return GraphKernel(W, κ, σ) |
| 194 | +end |
| 195 | + |
| 196 | +Flux.@functor GraphKernel |
| 197 | + |
| 198 | +function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij) |
| 199 | + return l.κ(vcat(x_i, x_j)) |
| 200 | +end |
| 201 | + |
| 202 | +function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray) |
| 203 | + return l.σ.(GeometricFlux._matmul(l.linear, x) + m) |
| 204 | +end |
| 205 | + |
| 206 | +function (l::GraphKernel)(el::NamedTuple, X::AbstractArray) |
| 207 | + GraphSignals.check_num_nodes(el.N, X) |
| 208 | + _, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing) |
| 209 | + return V |
| 210 | +end |
| 211 | + |
| 212 | +function Base.show(io::IO, l::GraphKernel) |
| 213 | + channel, _ = size(l.linear) |
| 214 | + print(io, "GraphKernel(", l.κ, ", channel=", channel) |
| 215 | + l.σ == identity || print(io, ", ", l.σ) |
| 216 | + print(io, ")") |
| 217 | +end |
| 218 | + |
| 219 | + |
173 | 220 | #########
|
174 | 221 | # utils #
|
175 | 222 | #########
|
|
0 commit comments