|
41 | 41 | julia> SpectralConv(2=>5, (16, ))
|
42 | 42 | SpectralConv(2 => 5, (16,), σ=identity)
|
43 | 43 |
|
| 44 | +julia> using Flux |
| 45 | +
|
44 | 46 | julia> SpectralConv(2=>5, (16, ), relu)
|
45 | 47 | SpectralConv(2 => 5, (16,), σ=relu)
|
46 | 48 |
|
47 | 49 | julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
|
48 | 50 | SpectralConvPerm(2 => 5, (16,), σ=relu)
|
49 |
| -```` |
| 51 | +``` |
50 | 52 | """
|
51 | 53 | function SpectralConv(
|
52 | 54 | ch::Pair{S, S},
|
@@ -104,17 +106,70 @@ end
|
104 | 106 | # operator #
|
105 | 107 | ############
|
106 | 108 |
|
| 109 | +""" |
| 110 | + FourierOperator(ch, modes, σ=identity; permuted=false) |
| 111 | +
|
| 112 | +## FourierOperator |
| 113 | +
|
| 114 | +* ``v(x)``: input |
| 115 | +* ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform |
| 116 | +* ``L``: linear transform on the lower Fouier modes |
| 117 | +* ``D``: local linear transform |
| 118 | +
|
| 119 | +``` |
| 120 | + ┌ F -> L -> F¯¹ ┐ |
| 121 | +v(x) -> ┤ ├ -> + -> σ |
| 122 | + └ D ┘ |
| 123 | +``` |
| 124 | +
|
| 125 | +## Example |
| 126 | +
|
| 127 | +```jldoctest |
| 128 | +julia> FourierOperator(2=>5, (16, )) |
| 129 | +Chain( |
| 130 | + Parallel( |
| 131 | + +, |
| 132 | + Dense(2, 5), # 15 parameters |
| 133 | + SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters |
| 134 | + ), |
| 135 | + NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity), |
| 136 | +) # Total: 3 arrays, 175 parameters, 1.668 KiB. |
| 137 | +
|
| 138 | +julia> using Flux |
| 139 | +
|
| 140 | +julia> FourierOperator(2=>5, (16, ), relu) |
| 141 | +Chain( |
| 142 | + Parallel( |
| 143 | + +, |
| 144 | + Dense(2, 5), # 15 parameters |
| 145 | + SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters |
| 146 | + ), |
| 147 | + NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu), |
| 148 | +) # Total: 3 arrays, 175 parameters, 1.668 KiB. |
| 149 | +
|
| 150 | +julia> FourierOperator(2=>5, (16, ), relu, permuted=true) |
| 151 | +Chain( |
| 152 | + Parallel( |
| 153 | + +, |
| 154 | + Conv((1,), 2 => 5), # 15 parameters |
| 155 | + SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters |
| 156 | + ), |
| 157 | + NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu), |
| 158 | +) # Total: 3 arrays, 175 parameters, 1.871 KiB. |
| 159 | +``` |
| 160 | +""" |
107 | 161 | function FourierOperator(
|
108 | 162 | ch::Pair{S, S},
|
109 | 163 | modes::NTuple{N, S},
|
110 | 164 | σ=identity;
|
111 | 165 | permuted=false
|
112 | 166 | ) where {S<:Integer, N}
|
113 | 167 | short_cut = permuted ? Conv(Tuple(ones(Int, length(modes))), ch) : Dense(ch.first, ch.second)
|
| 168 | + activation_func(x) = σ.(x) |
114 | 169 |
|
115 | 170 | return Chain(
|
116 | 171 | Parallel(+, short_cut, SpectralConv(ch, modes, permuted=permuted)),
|
117 |
| - x -> σ.(x) |
| 172 | + activation_func |
118 | 173 | )
|
119 | 174 | end
|
120 | 175 |
|
|
0 commit comments