Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 54602e6

Browse files
authored
Merge pull request #82 from yuehhua/abstract
Add abstract types for models and kernels
2 parents 63f0c58 + 77c0f6e commit 54602e6

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

src/DeepONet/DeepONet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
8181
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
8282
```
8383
"""
84-
struct DeepONet{T1, T2}
84+
struct DeepONet{T1, T2} <: AbstractOperatorModel
8585
branch_net::T1
8686
trunk_net::T2
8787
end

src/FNO/FNO.jl

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ export
22
FourierNeuralOperator,
33
MarkovNeuralOperator
44

5+
struct FourierNeuralOperator{L, K, P} <: AbstractOperatorModel
6+
lifting_net::L
7+
integral_kernel_net::K
8+
project_net::P
9+
end
10+
11+
Flux.@functor FourierNeuralOperator
12+
513
"""
614
FourierNeuralOperator(;
715
ch = (2, 64, 64, 64, 64, 64, 128, 1),
@@ -85,16 +93,31 @@ function FourierNeuralOperator(;
8593
modes = (16,),
8694
σ = gelu)
8795
Transform = FourierTransform
96+
lifting = Dense(ch[1], ch[2])
97+
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
98+
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
99+
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
100+
OperatorKernel(ch[5] => ch[6], modes, Transform))
101+
project = Chain(Dense(ch[6], ch[7], σ),
102+
Dense(ch[7], ch[8]))
103+
104+
return FourierNeuralOperator(lifting, mapping, project)
105+
end
106+
107+
function (fno::FourierNeuralOperator)(𝐱::AbstractArray)
108+
lifted = fno.lifting_net(𝐱)
109+
mapped = fno.integral_kernel_net(lifted)
110+
𝐲 = fno.project_net(mapped)
88111

89-
return Chain(Dense(ch[1], ch[2]),
90-
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
91-
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
92-
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
93-
OperatorKernel(ch[5] => ch[6], modes, Transform),
94-
Dense(ch[6], ch[7], σ),
95-
Dense(ch[7], ch[8]))
112+
return 𝐲
96113
end
97114

115+
struct MarkovNeuralOperator{F} <: AbstractOperatorModel
116+
fno::F
117+
end
118+
119+
Flux.@functor MarkovNeuralOperator
120+
98121
"""
99122
MarkovNeuralOperator(;
100123
ch = (1, 64, 64, 64, 64, 64, 1),
@@ -176,11 +199,15 @@ function MarkovNeuralOperator(;
176199
modes = (24, 24),
177200
σ = gelu)
178201
Transform = FourierTransform
179-
180-
return Chain(Dense(ch[1], ch[2]),
181-
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
182-
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
183-
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
184-
OperatorKernel(ch[5] => ch[6], modes, Transform, σ),
185-
Dense(ch[6], ch[7]))
202+
lifting = Dense(ch[1], ch[2])
203+
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
204+
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
205+
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
206+
OperatorKernel(ch[5] => ch[6], modes, Transform, σ))
207+
project = Dense(ch[6], ch[7])
208+
fno = FourierNeuralOperator(lifting, mapping, project)
209+
210+
return MarkovNeuralOperator(fno)
186211
end
212+
213+
(mno::MarkovNeuralOperator)(𝐱::AbstractArray) = mno.fno(𝐱)

src/NOMAD/NOMAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export NOMAD
22

3-
struct NOMAD{T1, T2}
3+
struct NOMAD{T1, T2} <: AbstractOperatorModel
44
approximator_net::T1
55
decoder_net::T2
66
end

src/NeuralOperators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ using ChainRulesCore
1111
using GeometricFlux
1212
using Statistics
1313

14+
include("abstracttypes.jl")
15+
1416
# kernels
1517
include("Transform/Transform.jl")
1618
include("operator_kernel.jl")

src/abstracttypes.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export
2+
AbstractOperatorModel,
3+
AbstractOperatorKernel
4+
5+
abstract type AbstractOperatorModel end
6+
abstract type AbstractOperatorKernel end

src/operator_kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ end
113113
# operator #
114114
############
115115

116-
struct OperatorKernel{L, C, F}
116+
struct OperatorKernel{L, C, F} <: AbstractOperatorKernel
117117
linear::L
118118
conv::C
119119
σ::F

0 commit comments

Comments
 (0)