11"""
2- OperatorConv([rng::AbstractRNG = __defautl_rng ()], ch::Pair{<:Integer, <:Integer},
3- _modes ::NTuple{N, <:Integer}, ::Type{TR}; init_weight = glorot_uniform,
2+ OperatorConv([rng::AbstractRNG = __default_rng ()], ch::Pair{<:Integer, <:Integer},
3+ modes ::NTuple{N, <:Integer}, ::Type{TR}; init_weight = glorot_uniform,
44 T::Type{TP} = ComplexF32,
55 permuted::Val{P} = Val(false)) where {N, TR <: AbstractTransform, TP, P}
66
@@ -40,13 +40,13 @@ function OperatorConv(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
4040 name = " OperatorConv{$TR }($in_chs => $out_chs , $modes ; permuted = $permuted )"
4141
4242 if permuted === True
43- return @compact (; modes, weights, transform,
43+ return @compact (; modes, weights, transform, dispatch = :OperatorConv ,
4444 name) do x:: AbstractArray{<:Real, M} where {M}
4545 y = __operator_conv (x, transform, weights)
4646 return y
4747 end
4848 else
49- return @compact (; modes, weights, transform,
49+ return @compact (; modes, weights, transform, dispatch = :OperatorConv ,
5050 name) do x:: AbstractArray{<:Real, M} where {M}
5151 N_ = ndims (transform)
5252 xᵀ = permutedims (x, (ntuple (i -> i + 1 , N_)... , 1 , N_ + 2 ))
@@ -57,8 +57,6 @@ function OperatorConv(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
5757 end
5858end
5959
60- OperatorConv (args... ; kwargs... ) = OperatorConv (__default_rng (), args... ; kwargs... )
61-
6260"""
6361 SpectralConv(args...; kwargs...)
6462
@@ -78,7 +76,7 @@ OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 paramete
7876SpectralConv (args... ; kwargs... ) = OperatorConv (args... , FourierTransform; kwargs... )
7977
8078"""
81- OperatorKernel([rng::AbstractRNG = __defautl_rng ()], ch::Pair{<:Integer, <:Integer},
79+ OperatorKernel([rng::AbstractRNG = __default_rng ()], ch::Pair{<:Integer, <:Integer},
8280 modes::NTuple{N, <:Integer}, transform::Type{TR}; σ = identity,
8381 permuted::Val{P} = Val(false), kwargs...) where {N, TR <: AbstractTransform, P}
8482
@@ -128,7 +126,8 @@ function OperatorKernel(rng::AbstractRNG, ch::Pair{<:Integer, <:Integer},
128126 l₁ = permuted === True ? Conv (map (_ -> 1 , modes), ch) : Dense (ch)
129127 l₂ = OperatorConv (rng, ch, modes, transform; permuted, kwargs... )
130128
131- return @compact (; l₁, l₂, activation= σ) do x:: AbstractArray{<:Real, M} where {M}
129+ return @compact (; l₁, l₂, activation= σ,
130+ dispatch= :OperatorKernel ) do x:: AbstractArray{<:Real, M} where {M}
132131 return activation .(l₁ (x) .+ l₂ (x))
133132 end
134133end
@@ -165,7 +164,60 @@ end # Total: 175 parameters,
165164"""
166165SpectralKernel (args... ; kwargs... ) = OperatorKernel (args... , FourierTransform; kwargs... )
167166
168- OperatorKernel (args... ; kwargs... ) = OperatorKernel (__default_rng (), args... ; kwargs... )
167+ # Building Blocks
168+ function BasicBlock (rng:: AbstractRNG , ch:: Integer , modes:: NTuple{N, <:Integer} , args... ;
169+ add_mlp:: Val = False, normalize:: Val = False, σ= gelu, kwargs... ) where {N}
170+ conv1 = SpectralConv (rng, ch => ch, modes, args... ; kwargs... )
171+ conv2 = SpectralConv (rng, ch => ch, modes, args... ; kwargs... )
172+ conv3 = SpectralConv (rng, ch => ch, modes, args... ; kwargs... )
173+
174+ kernel_size = map (_ -> 1 , modes)
175+
176+ if add_mlp === True
177+ mlp1 = Chain (Conv (kernel_size, ch => ch, σ), Conv (kernel_size, ch => ch))
178+ mlp2 = Chain (Conv (kernel_size, ch => ch, σ), Conv (kernel_size, ch => ch))
179+ mlp3 = Chain (Conv (kernel_size, ch => ch, σ), Conv (kernel_size, ch => ch))
180+ else
181+ mlp1 = NoOpLayer ()
182+ mlp2 = NoOpLayer ()
183+ mlp3 = NoOpLayer ()
184+ end
185+
186+ norm = normalize === True ? InstanceNorm (ch; affine = false ) : NoOpLayer ()
187+
188+ w1 = Conv (kernel_size, ch => ch)
189+ w2 = Conv (kernel_size, ch => ch)
190+ w3 = Conv (kernel_size, ch => ch)
191+
192+ return @compact (; conv1, conv2, conv3, mlp1, mlp2, mlp3, w1, w2, w3, norm, σ,
193+ dispatch= :BasicBlock ) do (inp,)
194+ x, injection = __destructure (inp)
195+
196+ x = norm (x)
197+ x1 = norm (mlp1 (norm (conv1 (x))))
198+ x2 = norm (w1 (x))
199+ x = σ .(x1 .+ x2 .+ injection)
200+
201+ x1 = norm (mlp2 (norm (conv2 (x))))
202+ x2 = norm (w2 (x))
203+ x = σ .(x1 .+ x2 .+ injection)
204+
205+ x1 = norm (mlp3 (norm (conv3 (x))))
206+ x2 = norm (w3 (x))
207+ return σ .(x1 .+ x2 .+ injection)
208+ end
209+ end
210+
211+ function StackedBasicBlock (rng:: AbstractRNG , args... ; depth:: Val{N} = Val (1 ),
212+ kwargs... ) where {N}
213+ blocks = ntuple (i -> BasicBlock (rng, args... ; kwargs... ), N)
214+ block = NamedTuple {ntuple(i -> Symbol("block_", i), N)} (blocks)
215+
216+ return @compact (; block, dispatch= :StackedBasicBlock ) do (inp,)
217+ x, injection = __destructure (inp)
218+ return __applyblock (block, x, injection)
219+ end
220+ end
169221
170222# Functional Versions
171223@inline function __operator_conv (x, transform, weights)
191243 return permutedims (res, (3 , 1 , 2 )) # m x o x b
192244end
193245
246+ @inline @generated function __applyblock (block:: NamedTuple{names} , x, inj) where {names}
247+ calls = [:(x = block.$ name ((x, inj))) for name in names]
248+ return quote
249+ $ (calls... )
250+ end
251+ end
252+
194253@inline __pad_modes (x, dims:: Integer... ) = __pad_modes (x, dims)
195254@inline __pad_modes (x, dims:: NTuple ) = __pad_modes! (similar (x, dims), x)
196255
0 commit comments