
Commit f42964c

Add more building blocks
1 parent f566776 commit f42964c

File tree: 5 files changed, +93 −22 lines changed

Project.toml

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

 [compat]
 ChainRulesCore = "1"

src/LuxNeuralOperators.jl

Lines changed: 15 additions & 3 deletions

@@ -4,12 +4,12 @@ import PrecompileTools: @recompile_invalidations
 import Reexport: @reexport

 @recompile_invalidations begin
-    using ArrayInterface, FFTW, Lux, Random
+    using ArrayInterface, FFTW, Lux, Random, SciMLBase

     import ChainRulesCore as CRC
     import Lux.Experimental: @compact
-    import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer,
-                    initialparameters, initialstates
+    import LuxCore: AbstractExplicitLayer,
+                    AbstractExplicitContainerLayer, initialparameters, initialstates
     import Random: AbstractRNG
 end

@@ -23,6 +23,18 @@ const False = Val(false)
 include("transform.jl")
 include("layers.jl")
 include("fno.jl")
+include("deq.jl")
+
+# Pass `rng` if user doesn't pass it
+for f in (:BasicBlock, :StackedBasicBlock, :OperatorConv, :OperatorKernel,
+    :FourierNeuralOperator)
+    @eval begin
+        $(f)(args...; kwargs...) = $(f)(__default_rng(), args...; kwargs...)
+    end
+end
+
+__destructure(x::Tuple) = x
+__destructure(x) = x, zero(eltype(x))

 export FourierTransform
 export SpectralConv, OperatorConv
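
The `@eval` loop above metaprograms one rng-forwarding fallback per public constructor (replacing the hand-written fallbacks deleted from src/layers.jl below), and `__destructure` normalizes block inputs so the same layer can take either a bare array or an `(x, injection)` tuple. A minimal standalone sketch of both patterns, assuming a `__default_rng` along these lines and a hypothetical `MyLayer` constructor (neither is part of this commit):

using Random

__default_rng() = Xoshiro(0)   # assumption: stand-in for the package's default RNG

# Hypothetical constructor whose primary method takes an explicit RNG first
MyLayer(rng::AbstractRNG, ch::Pair; kwargs...) = (; rng, ch)

# The generated fallback: calls without an RNG forward to the RNG-taking method
for f in (:MyLayer,)
    @eval $(f)(args...; kwargs...) = $(f)(__default_rng(), args...; kwargs...)
end

# __destructure: tuples pass through, bare arrays get a zero injection paired in
__destructure(x::Tuple) = x
__destructure(x) = x, zero(eltype(x))

MyLayer(2 => 4)                         # uses __default_rng()
__destructure(rand(Float32, 4))         # -> (x, 0.0f0)
__destructure((rand(Float32, 4), 1f0))  # -> (x, 1.0f0)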

src/deq.jl

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Basic DEQ Implementation
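
src/deq.jl is only a placeholder so far, but together with the new SciMLBase dependency and the injection term threaded through the blocks in src/layers.jl, it points at the deep-equilibrium (DEQ) pattern: rather than stacking more layers, solve for a fixed point z* = f(z*, x). A toy sketch of that idea only (plain Picard iteration; this is not the package's implementation, which would presumably route through a SciMLBase solver):

# Naive fixed-point iteration z_{k+1} = f(z_k, x); `f` stands in for a
# block called in DEQ style as block((z, injection))
function naive_deq(f, x; iters=100, tol=1f-4)
    z = zero(x)
    for _ in 1:iters
        z_next = f(z, x)
        maximum(abs, z_next .- z) < tol && return z_next
        z = z_next
    end
    return z
end

# Toy contraction map, so the iteration provably converges
f(z, x) = tanh.(0.5f0 .* z .+ x)
x = randn(Float32, 8)
z_star = naive_deq(f, x)   # z_star ≈ f(z_star, x)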

src/fno.jl

Lines changed: 8 additions & 10 deletions

@@ -1,6 +1,6 @@
 """
-    FourierNeuralOperator(; chs = (2, 64, 64, 64, 64, 64, 128, 1), modes = (16,),
-        σ = gelu, rng = nothing, permuted::Val = Val(false), kwargs...)
+    FourierNeuralOperator([rng = __default_rng()]; chs = (2, 64, 64, 64, 64, 64, 128, 1),
+        modes = (16,), σ = gelu, permuted::Val = Val(false), kwargs...)

 Fourier neural operator is a operator learning model that uses Fourier kernel to perform
 spectral convolutions. It is a promising way for surrogate methods, and can be regarded as
@@ -16,7 +16,6 @@ kernels, and two `Dense` layers to project data back to the scalar field of inte
   - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension
     of data.
   - `σ`: Activation function for all layers in the model.
-  - `rng`: Random number generator. If provided, it is forwarded to the Operator Layers.
   - `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts
     data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is
     `(x_1, ... , x_d, ch, batch)`.
@@ -35,28 +34,28 @@ Chain(
         l1 = Dense(64 => 64),  # 4_160 parameters
         l2 = OperatorConv{FourierTransform}(64 => 64, (16,); permuted = false)(),  # 65_536 parameters, plus 2
         activation = σ,
-    ) do x::(AbstractArray{<:Real, M} where M)
+    ) do x::(AbstractArray{<:Real, M} where M)
         return activation.(l1(x) .+ l2(x))
     end,
     layer_2 = @compact(
         l1 = Dense(64 => 64),  # 4_160 parameters
         l2 = OperatorConv{FourierTransform}(64 => 64, (16,); permuted = false)(),  # 65_536 parameters, plus 2
         activation = σ,
-    ) do x::(AbstractArray{<:Real, M} where M)
+    ) do x::(AbstractArray{<:Real, M} where M)
         return activation.(l1(x) .+ l2(x))
     end,
     layer_3 = @compact(
         l1 = Dense(64 => 64),  # 4_160 parameters
         l2 = OperatorConv{FourierTransform}(64 => 64, (16,); permuted = false)(),  # 65_536 parameters, plus 2
         activation = σ,
-    ) do x::(AbstractArray{<:Real, M} where M)
+    ) do x::(AbstractArray{<:Real, M} where M)
         return activation.(l1(x) .+ l2(x))
     end,
     layer_4 = @compact(
         l1 = Dense(64 => 64),  # 4_160 parameters
         l2 = OperatorConv{FourierTransform}(64 => 64, (16,); permuted = false)(),  # 65_536 parameters, plus 2
         activation = σ,
-    ) do x::(AbstractArray{<:Real, M} where M)
+    ) do x::(AbstractArray{<:Real, M} where M)
         return activation.(l1(x) .+ l2(x))
     end,
 ),
@@ -68,10 +67,9 @@ Chain(
 # plus 12 states.
 ```
 """
-function FourierNeuralOperator(; chs=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16,),
-        σ=gelu, rng=nothing, permuted::Val{P}=False, kwargs...) where {P}
+function FourierNeuralOperator(rng::AbstractRNG; chs=(2, 64, 64, 64, 64, 64, 128, 1),
+        modes=(16,), σ=gelu, permuted::Val{P}=False, kwargs...) where {P}
     @assert length(chs) ≥ 5
-    rng === nothing && (rng = __default_rng())

     map₁ = chs[1] => chs[2]
     map₂ = chs[end - 2] => chs[end - 1]
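
After this change the RNG is the first positional argument; keyword-only calls keep working through the forwarding loop added in src/LuxNeuralOperators.jl. A usage sketch against this commit (shapes follow the documented non-permuted layout; untested here):

using Lux, Random

# Keyword-only form: the @eval fallback supplies __default_rng()
fno = FourierNeuralOperator(; chs=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16,))

# Reproducible form: pass the RNG positionally
fno_seeded = FourierNeuralOperator(Xoshiro(0); σ=gelu)

ps, st = Lux.setup(Xoshiro(0), fno_seeded)
x = rand(Float32, 1024, 2, 4)   # (x_1, ch, batch), since permuted = Val(false)
y, _ = fno_seeded(x, ps, st)    # y should have size (1024, 1, 4)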

src/layers.jl

Lines changed: 68 additions & 9 deletions

@@ -1,6 +1,6 @@
 """
-    OperatorConv([rng::AbstractRNG = __defautl_rng()], ch::Pair{<:Integer, <:Integer},
-        _modes::NTuple{N, <:Integer}, ::Type{TR}; init_weight = glorot_uniform,
+    OperatorConv([rng::AbstractRNG = __default_rng()], ch::Pair{<:Integer, <:Integer},
+        modes::NTuple{N, <:Integer}, ::Type{TR}; init_weight = glorot_uniform,
         T::Type{TP} = ComplexF32,
         permuted::Val{P} = Val(false)) where {N, TR <: AbstractTransform, TP, P}

@@ -40,13 +40,13 @@ function OperatorConv(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
     name = "OperatorConv{$TR}($in_chs => $out_chs, $modes; permuted = $permuted)"

     if permuted === True
-        return @compact(; modes, weights, transform,
+        return @compact(; modes, weights, transform, dispatch=:OperatorConv,
             name) do x::AbstractArray{<:Real, M} where {M}
             y = __operator_conv(x, transform, weights)
             return y
         end
     else
-        return @compact(; modes, weights, transform,
+        return @compact(; modes, weights, transform, dispatch=:OperatorConv,
             name) do x::AbstractArray{<:Real, M} where {M}
             N_ = ndims(transform)
             xᵀ = permutedims(x, (ntuple(i -> i + 1, N_)..., 1, N_ + 2))
@@ -57,8 +57,6 @@ function OperatorConv(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
     end
 end

-OperatorConv(args...; kwargs...) = OperatorConv(__default_rng(), args...; kwargs...)
-
 """
     SpectralConv(args...; kwargs...)

@@ -78,7 +76,7 @@ OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 paramete
 SpectralConv(args...; kwargs...) = OperatorConv(args..., FourierTransform; kwargs...)

 """
-    OperatorKernel([rng::AbstractRNG = __defautl_rng()], ch::Pair{<:Integer, <:Integer},
+    OperatorKernel([rng::AbstractRNG = __default_rng()], ch::Pair{<:Integer, <:Integer},
         modes::NTuple{N, <:Integer}, transform::Type{TR}; σ = identity,
         permuted::Val{P} = Val(false), kwargs...) where {N, TR <: AbstractTransform, P}

@@ -128,7 +126,8 @@ function OperatorKernel(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
     l₁ = permuted === True ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
     l₂ = OperatorConv(rng, ch, modes, transform; permuted, kwargs...)

-    return @compact(; l₁, l₂, activation=σ) do x::AbstractArray{<:Real, M} where {M}
+    return @compact(; l₁, l₂, activation=σ,
+        dispatch=:OperatorKernel) do x::AbstractArray{<:Real, M} where {M}
         return activation.(l₁(x) .+ l₂(x))
     end
 end

@@ -165,7 +164,60 @@ end # Total: 175 parameters,
 """
 SpectralKernel(args...; kwargs...) = OperatorKernel(args..., FourierTransform; kwargs...)

-OperatorKernel(args...; kwargs...) = OperatorKernel(__default_rng(), args...; kwargs...)
+# Building Blocks
+function BasicBlock(rng::AbstractRNG, ch::Integer, modes::NTuple{N, <:Integer}, args...;
+        add_mlp::Val=False, normalize::Val=False, σ=gelu, kwargs...) where {N}
+    conv1 = SpectralConv(rng, ch => ch, modes, args...; kwargs...)
+    conv2 = SpectralConv(rng, ch => ch, modes, args...; kwargs...)
+    conv3 = SpectralConv(rng, ch => ch, modes, args...; kwargs...)
+
+    kernel_size = map(_ -> 1, modes)
+
+    if add_mlp === True
+        mlp1 = Chain(Conv(kernel_size, ch => ch, σ), Conv(kernel_size, ch => ch))
+        mlp2 = Chain(Conv(kernel_size, ch => ch, σ), Conv(kernel_size, ch => ch))
+        mlp3 = Chain(Conv(kernel_size, ch => ch, σ), Conv(kernel_size, ch => ch))
+    else
+        mlp1 = NoOpLayer()
+        mlp2 = NoOpLayer()
+        mlp3 = NoOpLayer()
+    end
+
+    norm = normalize === True ? InstanceNorm(ch; affine = false) : NoOpLayer()
+
+    w1 = Conv(kernel_size, ch => ch)
+    w2 = Conv(kernel_size, ch => ch)
+    w3 = Conv(kernel_size, ch => ch)
+
+    return @compact(; conv1, conv2, conv3, mlp1, mlp2, mlp3, w1, w2, w3, norm, σ,
+        dispatch=:BasicBlock) do (inp,)
+        x, injection = __destructure(inp)
+
+        x = norm(x)
+        x1 = norm(mlp1(norm(conv1(x))))
+        x2 = norm(w1(x))
+        x = σ.(x1 .+ x2 .+ injection)
+
+        x1 = norm(mlp2(norm(conv2(x))))
+        x2 = norm(w2(x))
+        x = σ.(x1 .+ x2 .+ injection)
+
+        x1 = norm(mlp3(norm(conv3(x))))
+        x2 = norm(w3(x))
+        return σ.(x1 .+ x2 .+ injection)
+    end
+end
+
+function StackedBasicBlock(rng::AbstractRNG, args...; depth::Val{N} = Val(1),
+        kwargs...) where {N}
+    blocks = ntuple(i -> BasicBlock(rng, args...; kwargs...), N)
+    block = NamedTuple{ntuple(i -> Symbol("block_", i), N)}(blocks)
+
+    return @compact(; block, dispatch=:StackedBasicBlock) do (inp,)
+        x, injection = __destructure(inp)
+        return __applyblock(block, x, injection)
+    end
+end

 # Functional Versions
 @inline function __operator_conv(x, transform, weights)

@@ -191,6 +243,13 @@ end
     return permutedims(res, (3, 1, 2)) # m x o x b
 end

+@inline @generated function __applyblock(block::NamedTuple{names}, x, inj) where {names}
+    calls = [:(x = block.$name((x, inj))) for name in names]
+    return quote
+        $(calls...)
+    end
+end
+
 @inline __pad_modes(x, dims::Integer...) = __pad_modes(x, dims)
 @inline __pad_modes(x, dims::NTuple) = __pad_modes!(similar(x, dims), x)
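
A sketch of how the new blocks compose. `__applyblock` is a generated function that unrolls the `NamedTuple` of blocks at compile time, so for `depth = Val(2)` its body is equivalent to `x = block.block_1((x, inj)); x = block.block_2((x, inj))`. The usage below is inferred from `__destructure` and the `(x, injection)` convention, with arbitrary sizes (untested against this commit):

using Lux, Random

# Two chained BasicBlocks over 8 channels, keeping 4 Fourier modes each
model = StackedBasicBlock(Xoshiro(0), 8, (4,); depth=Val(2), add_mlp=Val(true))
ps, st = Lux.setup(Xoshiro(0), model)

x = rand(Float32, 16, 8, 2)     # (x_1, ch, batch), non-permuted layout
inj = zero(x)                   # DEQ-style injection term added inside each block

y, _ = model((x, inj), ps, st)  # explicit (x, injection) tuple input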