|
2 | 2 | FourierNeuralOperator,
|
3 | 3 | MarkovNeuralOperator
|
4 | 4 |
|
| 5 | +struct FourierNeuralOperator{L, K, P} <: AbstractOperatorModel |
| 6 | + lifting_net::L |
| 7 | + integral_kernel_net::K |
| 8 | + project_net::P |
| 9 | +end |
| 10 | + |
| 11 | +Flux.@functor FourierNeuralOperator |
| 12 | + |
5 | 13 | """
|
6 | 14 | FourierNeuralOperator(;
|
7 | 15 | ch = (2, 64, 64, 64, 64, 64, 128, 1),
|
@@ -85,16 +93,31 @@ function FourierNeuralOperator(;
|
85 | 93 | modes = (16,),
|
86 | 94 | σ = gelu)
|
87 | 95 | Transform = FourierTransform
|
| 96 | + lifting = Dense(ch[1], ch[2]) |
| 97 | + mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ), |
| 98 | + OperatorKernel(ch[3] => ch[4], modes, Transform, σ), |
| 99 | + OperatorKernel(ch[4] => ch[5], modes, Transform, σ), |
| 100 | + OperatorKernel(ch[5] => ch[6], modes, Transform)) |
| 101 | + project = Chain(Dense(ch[6], ch[7], σ), |
| 102 | + Dense(ch[7], ch[8])) |
| 103 | + |
| 104 | + return FourierNeuralOperator(lifting, mapping, project) |
| 105 | +end |
| 106 | + |
| 107 | +function (fno::FourierNeuralOperator)(𝐱::AbstractArray) |
| 108 | + lifted = fno.lifting_net(𝐱) |
| 109 | + mapped = fno.integral_kernel_net(lifted) |
| 110 | + 𝐲 = fno.project_net(mapped) |
88 | 111 |
|
89 |
| - return Chain(Dense(ch[1], ch[2]), |
90 |
| - OperatorKernel(ch[2] => ch[3], modes, Transform, σ), |
91 |
| - OperatorKernel(ch[3] => ch[4], modes, Transform, σ), |
92 |
| - OperatorKernel(ch[4] => ch[5], modes, Transform, σ), |
93 |
| - OperatorKernel(ch[5] => ch[6], modes, Transform), |
94 |
| - Dense(ch[6], ch[7], σ), |
95 |
| - Dense(ch[7], ch[8])) |
| 112 | + return 𝐲 |
96 | 113 | end
|
97 | 114 |
|
| 115 | +struct MarkovNeuralOperator{F} <: AbstractOperatorModel |
| 116 | + fno::F |
| 117 | +end |
| 118 | + |
| 119 | +Flux.@functor MarkovNeuralOperator |
| 120 | + |
98 | 121 | """
|
99 | 122 | MarkovNeuralOperator(;
|
100 | 123 | ch = (1, 64, 64, 64, 64, 64, 1),
|
@@ -176,11 +199,15 @@ function MarkovNeuralOperator(;
|
176 | 199 | modes = (24, 24),
|
177 | 200 | σ = gelu)
|
178 | 201 | Transform = FourierTransform
|
179 |
| - |
180 |
| - return Chain(Dense(ch[1], ch[2]), |
181 |
| - OperatorKernel(ch[2] => ch[3], modes, Transform, σ), |
182 |
| - OperatorKernel(ch[3] => ch[4], modes, Transform, σ), |
183 |
| - OperatorKernel(ch[4] => ch[5], modes, Transform, σ), |
184 |
| - OperatorKernel(ch[5] => ch[6], modes, Transform, σ), |
185 |
| - Dense(ch[6], ch[7])) |
| 202 | + lifting = Dense(ch[1], ch[2]) |
| 203 | + mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ), |
| 204 | + OperatorKernel(ch[3] => ch[4], modes, Transform, σ), |
| 205 | + OperatorKernel(ch[4] => ch[5], modes, Transform, σ), |
| 206 | + OperatorKernel(ch[5] => ch[6], modes, Transform, σ)) |
| 207 | + project = Dense(ch[6], ch[7]) |
| 208 | + fno = FourierNeuralOperator(lifting, mapping, project) |
| 209 | + |
| 210 | + return MarkovNeuralOperator(fno) |
186 | 211 | end
|
| 212 | + |
| 213 | +(mno::MarkovNeuralOperator)(𝐱::AbstractArray) = mno.fno(𝐱) |
0 commit comments