Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 77c0f6e

Browse files
committed
add type for FNO and MNO
1 parent 2344a1c commit 77c0f6e

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

src/FNO/FNO.jl

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ export
22
FourierNeuralOperator,
33
MarkovNeuralOperator
44

5+
struct FourierNeuralOperator{L, K, P} <: AbstractOperatorModel
6+
lifting_net::L
7+
integral_kernel_net::K
8+
project_net::P
9+
end
10+
11+
Flux.@functor FourierNeuralOperator
12+
513
"""
614
FourierNeuralOperator(;
715
ch = (2, 64, 64, 64, 64, 64, 128, 1),
@@ -85,16 +93,31 @@ function FourierNeuralOperator(;
8593
modes = (16,),
8694
σ = gelu)
8795
Transform = FourierTransform
96+
lifting = Dense(ch[1], ch[2])
97+
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
98+
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
99+
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
100+
OperatorKernel(ch[5] => ch[6], modes, Transform))
101+
project = Chain(Dense(ch[6], ch[7], σ),
102+
Dense(ch[7], ch[8]))
103+
104+
return FourierNeuralOperator(lifting, mapping, project)
105+
end
106+
107+
function (fno::FourierNeuralOperator)(𝐱::AbstractArray)
108+
lifted = fno.lifting_net(𝐱)
109+
mapped = fno.integral_kernel_net(lifted)
110+
𝐲 = fno.project_net(mapped)
88111

89-
return Chain(Dense(ch[1], ch[2]),
90-
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
91-
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
92-
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
93-
OperatorKernel(ch[5] => ch[6], modes, Transform),
94-
Dense(ch[6], ch[7], σ),
95-
Dense(ch[7], ch[8]))
112+
return 𝐲
96113
end
97114

115+
struct MarkovNeuralOperator{F} <: AbstractOperatorModel
116+
fno::F
117+
end
118+
119+
Flux.@functor MarkovNeuralOperator
120+
98121
"""
99122
MarkovNeuralOperator(;
100123
ch = (1, 64, 64, 64, 64, 64, 1),
@@ -176,11 +199,15 @@ function MarkovNeuralOperator(;
176199
modes = (24, 24),
177200
σ = gelu)
178201
Transform = FourierTransform
179-
180-
return Chain(Dense(ch[1], ch[2]),
181-
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
182-
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
183-
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
184-
OperatorKernel(ch[5] => ch[6], modes, Transform, σ),
185-
Dense(ch[6], ch[7]))
202+
lifting = Dense(ch[1], ch[2])
203+
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
204+
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
205+
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
206+
OperatorKernel(ch[5] => ch[6], modes, Transform, σ))
207+
project = Dense(ch[6], ch[7])
208+
fno = FourierNeuralOperator(lifting, mapping, project)
209+
210+
return MarkovNeuralOperator(fno)
186211
end
212+
213+
(mno::MarkovNeuralOperator)(𝐱::AbstractArray) = mno.fno(𝐱)

0 commit comments

Comments
 (0)