11"""
2- OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::NTuple{N, <:Integer}, ::Type{TR};
3- init_weight = glorot_uniform, T ::Type{TP} = ComplexF32 ,
4- permuted::Val{P} = Val(false)) where {N, TR <: AbstractTransform, TP, P }
2+ OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::NTuple{N, <:Integer},
3+ ::Type{TR}; init_weight=glorot_uniform ,
4+ permuted::Val{perm}= Val(false)) where {N, TR <: AbstractTransform, perm }
55
66## Arguments
77
88 - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`.
99 - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of
1010 data.
11- - `::Type{TR}`: The traform to operate the transformation.
11+ - `::Type{TR}`: The transform to operate the transformation.
1212
1313## Keyword Arguments
1414
1515 - `init_weight`: Initial function to initialize parameters.
1616 - `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts
1717 data in the order of `(ch, x_1, ... , x_d, batch)`. Otherwise the order is
1818 `(x_1, ... , x_d, ch, batch)`.
19- - `T`: Datatype of parameters.
2019
2120## Example
2221
2322```jldoctest
24- julia> OperatorConv(2 => 5, (16,), FourierTransform)
23+ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32} )
2524OperatorConv{FourierTransform}(2 => 5, (16,); permuted = false)() # 160 parameters
2625
27- julia> OperatorConv(2 => 5, (16,), FourierTransform; permuted=Val(true))
26+ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32} ; permuted=Val(true))
2827OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 parameters
2928```
3029"""
31- @concrete struct OperatorConv{elType, perm, T <: AbstractTransform } <: AbstractExplicitLayer
30+ @concrete struct OperatorConv{perm, T <: AbstractTransform } <: AbstractExplicitLayer
3231 in_chs:: Int
3332 out_chs:: Int
3433 prod_modes:: Int
@@ -39,30 +38,30 @@ OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 paramete
3938 name:: String
4039end
4140
42- function LuxCore. initialparameters (
43- rng:: AbstractRNG , layer:: OperatorConv{elType} ) where {elType}
41+ function LuxCore. initialparameters (rng:: AbstractRNG , layer:: OperatorConv )
4442 in_chs, out_chs = layer. in_chs, layer. out_chs
45- scale = real (one (elType)) / (in_chs * out_chs)
46- weights = scale * layer. init_weight (rng, elType, out_chs, in_chs, layer. prod_modes)
47- return (; weights,)
43+ scale = real (one (eltype (layer. tform))) / (in_chs * out_chs)
44+ return (;
45+ weights= scale * layer. init_weight (
46+ rng, eltype (layer. tform), out_chs, in_chs, layer. prod_modes))
4847end
4948
5049@inline function LuxCore. parameterlength (layer:: OperatorConv )
5150 return layer. prod_modes * layer. in_chs * layer. out_chs
5251end
5352
5453function OperatorConv (ch:: Pair{<:Integer, <:Integer} , modes:: NTuple{N, <:Integer} ,
55- :: Type{TR} ; init_weight= glorot_uniform, T :: Type{TP} = ComplexF32,
56- permuted:: Val{perm} = Val (false )) where {N, TR <: AbstractTransform , TP , perm}
54+ :: Type{TR} ; init_weight= glorot_uniform,
55+ permuted:: Val{perm} = Val (false )) where {N, TR <: AbstractTransform{<:Number} , perm}
5756 name = " OperatorConv{$(string (nameof (TR))) }($(ch[1 ]) => $(ch[2 ]) , $modes ; permuted = $perm )"
58- return OperatorConv {TP, perm} (ch... , prod (modes), TR (modes), init_weight, name)
57+ return OperatorConv {perm} (ch... , prod (modes), TR (modes), init_weight, name)
5958end
6059
61- function (conv:: OperatorConv{T, true} )(x:: AbstractArray{<:Real, M} , ps, st) where {T, M}
60+ function (conv:: OperatorConv{true} )(x:: AbstractArray , ps, st)
6261 return operator_conv (x, conv. tform, ps. weights), st
6362end
6463
65- function (conv:: OperatorConv{T, false} )(x:: AbstractArray{<:Real, M} , ps, st) where {T, M}
64+ function (conv:: OperatorConv{false} )(x:: AbstractArray , ps, st)
6665 N = ndims (conv. tform)
6766 xᵀ = permutedims (x, (ntuple (i -> i + 1 , N)... , 1 , N + 2 ))
6867 yᵀ = operator_conv (xᵀ, conv. tform, ps. weights)
7372"""
7473 SpectralConv(args...; kwargs...)
7574
76- Construct a `OperatorConv` with `FourierTransform` as the transform. See
75+ Construct a `OperatorConv` with `FourierTransform{ComplexF32} ` as the transform. See
7776[`OperatorConv`](@ref) for the individual arguments.
7877
7978## Example
@@ -86,7 +85,8 @@ julia> SpectralConv(2 => 5, (16,); permuted=Val(true))
8685OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)() # 160 parameters
8786```
8887"""
89- SpectralConv (args... ; kwargs... ) = OperatorConv (args... , FourierTransform; kwargs... )
88+ SpectralConv (args... ; kwargs... ) = OperatorConv (
89+ args... , FourierTransform{ComplexF32}; kwargs... )
9090
9191"""
9292 OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
@@ -106,14 +106,13 @@ SpectralConv(args...; kwargs...) = OperatorConv(args..., FourierTransform; kwarg
106106 - `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts
107107 data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is
108108 `(x_1, ... , x_d, ch, batch)`.
109- - `T`: Datatype of parameters.
110109
111110All the keyword arguments are passed to the [`OperatorConv`](@ref) constructor.
112111
113112## Example
114113
115114```jldoctest
116- julia> OperatorKernel(2 => 5, (16,), FourierTransform)
115+ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64} )
117116@compact(
118117 l₁ = Dense(2 => 5), # 15 parameters
119118 l₂ = OperatorConv{FourierTransform}(2 => 5, (16,); permuted = false)(), # 160 parameters
@@ -125,7 +124,7 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform)
125124end # Total: 175 parameters,
126125 # plus 1 states.
127126
128- julia> OperatorKernel(2 => 5, (16,), FourierTransform; permuted=Val(true))
127+ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64} ; permuted=Val(true))
129128@compact(
130129 l₁ = Conv((1,), 2 => 5), # 15 parameters
131130 l₂ = OperatorConv{FourierTransform}(2 => 5, (16,); permuted = true)(), # 160 parameters
@@ -140,7 +139,7 @@ end # Total: 175 parameters,
140139"""
141140function OperatorKernel (ch:: Pair{<:Integer, <:Integer} , modes:: Dims{N} , transform:: Type{TR} ,
142141 act:: A = identity; allow_fast_activation:: Bool = false , permuted:: Val{perm} = Val (false ),
143- kwargs... ) where {N, TR <: AbstractTransform , perm, A}
142+ kwargs... ) where {N, TR <: AbstractTransform{<:Number} , perm, A}
144143 act = allow_fast_activation ? NNlib. fast_act (act) : act
145144 l₁ = perm ? Conv (map (_ -> 1 , modes), ch) : Dense (ch)
146145 l₂ = OperatorConv (ch, modes, transform; permuted, kwargs... )
155154"""
156155 SpectralKernel(args...; kwargs...)
157156
158- Construct a `OperatorKernel` with `FourierTransform` as the transform. See
157+ Construct a `OperatorKernel` with `FourierTransform{ComplexF32} ` as the transform. See
159158[`OperatorKernel`](@ref) for the individual arguments.
160159
161160## Example
@@ -188,5 +187,5 @@ end # Total: 175 parameters,
188187"""
189188function SpectralKernel (ch:: Pair{<:Integer, <:Integer} , modes:: Dims{N} ,
190189 act:: A = identity; kwargs... ) where {N, A}
191- return OperatorKernel (ch, modes, FourierTransform, act; kwargs... )
190+ return OperatorKernel (ch, modes, FourierTransform{ComplexF32} , act; kwargs... )
192191end
0 commit comments