Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 61e944e

Browse files
committed
add GraphKernel
add doc fix
1 parent 8ca0713 commit 61e944e

File tree

6 files changed

+92
-4
lines changed

6 files changed

+92
-4
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
12+
GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
1213
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
14+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1315
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1416
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1517

@@ -19,13 +21,15 @@ CUDAKernels = "0.3, 0.4"
1921
ChainRulesCore = "1.13"
2022
FFTW = "1.4"
2123
Flux = "0.12"
24+
GeometricFlux = "0.10"
2225
KernelAbstractions = "0.7, 0.8"
2326
Tullio = "0.3"
2427
Zygote = "0.6"
2528
julia = "1.6"
2629

2730
[extras]
31+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2832
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2933

3034
[targets]
31-
test = ["Test"]
35+
test = ["Graphs", "Test"]

docs/src/apis.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ OperatorKernel
4040

4141
Reference: [Fourier Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2010.08895)
4242

43+
### Graph kernel layer
44+
45+
```math
46+
v_{t+1}(x_i) = \sigma(W v_t(x_i) + \frac{1}{|\mathcal{N}(x_i)|} \sum_{x_j \in \mathcal{N}(x_i)} \kappa \{ v_t(x_i), v_t(x_j) \} )
47+
```
48+
49+
where ``v_t(x_i)`` is the input function for ``t``-th layer, ``x_i`` is the node feature for ``i``-th node and ``\mathcal{N}(x_i)`` represents the neighbors for ``x_i``.
50+
Activation function ``\sigma`` can be arbitrary non-linear function.
51+
52+
```@docs
53+
GraphKernel
54+
```
55+
56+
Reference: [Neural Operator: Graph Kernel Network for Partial Differential Equations](https://arxiv.org/abs/2003.03485)
57+
4358
## Models
4459

4560
### Fourier neural operator

src/NeuralOperators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module NeuralOperators
77
using KernelAbstractions
88
using Zygote
99
using ChainRulesCore
10+
using GeometricFlux
11+
using Statistics
1012

1113
export DeepONet
1214

src/operator_kernel.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
export
22
OperatorConv,
33
SpectralConv,
4-
OperatorKernel
4+
OperatorKernel,
5+
GraphKernel
56

67
struct OperatorConv{P, T, S, TT}
78
weight::T
@@ -170,6 +171,52 @@ function (m::OperatorKernel)(𝐱)
170171
return m.σ.(m.linear(𝐱) + m.conv(𝐱))
171172
end
172173

174+
"""
175+
GraphKernel(κ, ch, σ=identity)
176+
177+
Graph kernel layer.
178+
179+
## Arguments
180+
181+
* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
182+
* `ch`: Channel size for linear transform, e.g. `32`.
183+
* `σ`: Activation function.
184+
"""
185+
struct GraphKernel{A,B,F} <: MessagePassing
186+
linear::A
187+
κ::B
188+
σ::F
189+
end
190+
191+
function GraphKernel(κ, ch::Int, σ=identity; init=Flux.glorot_uniform)
192+
W = init(ch, ch)
193+
return GraphKernel(W, κ, σ)
194+
end
195+
196+
Flux.@functor GraphKernel
197+
198+
function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij)
199+
return l.κ(vcat(x_i, x_j))
200+
end
201+
202+
function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray)
203+
return l.σ.(GeometricFlux._matmul(l.linear, x) + m)
204+
end
205+
206+
function (l::GraphKernel)(el::NamedTuple, X::AbstractArray)
207+
GraphSignals.check_num_nodes(el.N, X)
208+
_, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing)
209+
return V
210+
end
211+
212+
function Base.show(io::IO, l::GraphKernel)
213+
channel, _ = size(l.linear)
214+
print(io, "GraphKernel(", l.κ, ", channel=", channel)
215+
l.σ == identity || print(io, ", ", l.σ)
216+
print(io, ")")
217+
end
218+
219+
173220
#########
174221
# utils #
175222
#########

test/operator_kernel.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,20 @@ end
154154
@test SpectralConv(ch, modes) isa OperatorConv
155155
@test SpectralConv(ch, modes).transform isa FourierTransform
156156
end
157+
158+
@testset "GraphKernel" begin
159+
batch_size = 5
160+
channel = 32
161+
N = 10*10
162+
163+
κ = Dense(2*channel, channel, relu)
164+
165+
graph = grid([10, 10])
166+
𝐱 = rand(Float32, channel, N, batch_size)
167+
l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
168+
@test repr(l.layer) == "GraphKernel(Dense(64, 32, relu), channel=32)"
169+
@test size(l(𝐱)) == (channel, N, batch_size)
170+
171+
g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))
172+
@test length(g.grads) == 3
173+
end

test/runtests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using NeuralOperators
2-
using Test
3-
using Flux
42
using CUDA
3+
using Flux
4+
using GeometricFlux
5+
using Graphs
6+
using Zygote
7+
using Test
58

69
CUDA.allowscalar(false)
710

0 commit comments

Comments
 (0)